[MLIR][Linalg] Lower vector.contract to chain of vector.fma for batch reduce matmul #121885
Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository, in which case you can instead tag reviewers by name in a comment by using `@` followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment "Ping". The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.
@rolfmorel - Request to add to the reviewer list.
rolfmorel left a comment
Per the below, I haven't gone through it all.
- As in the other PR, move all the files to the Vector dialect dir.
- Can we compose this transform out of existing transforms (with only slight amendments/modifications)? That is, would something like the following work?
unroll vector.contract (e.g. using UnrollContractionPattern)
  |   |
  |   v
  |   hoist transfer and strided_slice ops (using replaceWithAdditionalYields maybe? Or hoist_loop_invariant_subsets)
  |   |
  |   v
  |   lower extracted_slices to vector.load
  v
lower vector.contract to vector.fma (e.g. `transform.apply_patterns.vector.lower_contraction lowering_strategy = parallelarith`)
With clean-up where necessary. Most of these are exposed as transform ops, though maybe not everything.
Let's have a chat offline about the above.
/// suffices for achieving the sum.
void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
nit: Surrounding declarations only have one blank line in between.
%c0 = arith.constant 0 : index
%0 = memref.get_global @__constant_24x64x64xf32 : memref<24x64x64xf32>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x24x32x64xf32>
scf.forall (%arg1, %arg2) in (8, 24) {
This `scf.forall` and the next two `scf.for`s are extraneous. Please minimize your unit tests to what is necessary for the transform to be tested.
This transformation pass lowers the `vector.contract` operation (batch reduce matmul) for a GEMM of size MxN to a sequence of vector FMAs.
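
To make the description above concrete, here is a minimal C++ sketch of the idea, not the pass itself. It models a batch-reduce matmul, C[m][n] += sum over batch and k of A[batch][m][k] * B[batch][k][n], and rewrites it as a chain of lane-wise FMAs on one accumulator vector per row of C. Everything in it is an assumption made for illustration: the sizes, the type aliases, and the helper names (`Vec`, `broadcast`, `vecFma`, `brgemmReference`, `brgemmFmaChain`, `runSelfCheck`) are hypothetical, and `std::array` plus `std::fma` merely stand in for MLIR vector values, `vector.broadcast` and `vector.fma`.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Illustrative sizes only; they are not taken from the PR.
constexpr std::size_t kBatch = 2, kM = 3, kN = 4, kK = 5;

using Vec = std::array<float, kN>;  // stands in for one vector of N f32 lanes
using MatA = std::array<std::array<std::array<float, kK>, kM>, kBatch>;
using MatB = std::array<std::array<Vec, kK>, kBatch>;
using MatC = std::array<Vec, kM>;

// Splat a scalar across all lanes (the analogue of vector.broadcast).
Vec broadcast(float x) {
  Vec v{};
  v.fill(x);
  return v;
}

// Lane-wise fused multiply-add (the analogue of vector.fma).
Vec vecFma(const Vec &a, const Vec &b, const Vec &acc) {
  Vec out{};
  for (std::size_t i = 0; i < kN; ++i)
    out[i] = std::fma(a[i], b[i], acc[i]);
  return out;
}

// Reference semantics of the batch-reduce matmul, written as scalar loops.
MatC brgemmReference(const MatA &a, const MatB &b, MatC c) {
  for (std::size_t bt = 0; bt < kBatch; ++bt)
    for (std::size_t m = 0; m < kM; ++m)
      for (std::size_t n = 0; n < kN; ++n)
        for (std::size_t k = 0; k < kK; ++k)
          c[m][n] += a[bt][m][k] * b[bt][k][n];
  return c;
}

// Same computation as a chain of FMAs: each row of C is one accumulator
// vector, updated once per (batch, k) step with a broadcast element of A
// and a row of B.
MatC brgemmFmaChain(const MatA &a, const MatB &b, MatC c) {
  for (std::size_t bt = 0; bt < kBatch; ++bt)
    for (std::size_t k = 0; k < kK; ++k) {
      const Vec &bRow = b[bt][k];  // loaded once, reused for every row m
      for (std::size_t m = 0; m < kM; ++m)
        c[m] = vecFma(broadcast(a[bt][m][k]), bRow, c[m]);
    }
  return c;
}

// Small-integer inputs keep every float operation exact, so the two
// formulations can be compared with ==.
void runSelfCheck() {
  MatA a{};
  MatB b{};
  MatC c{};
  for (std::size_t bt = 0; bt < kBatch; ++bt)
    for (std::size_t k = 0; k < kK; ++k) {
      for (std::size_t m = 0; m < kM; ++m)
        a[bt][m][k] = static_cast<float>((bt + 2 * m + k) % 4);
      for (std::size_t n = 0; n < kN; ++n)
        b[bt][k][n] = static_cast<float>((bt + k + 3 * n) % 5);
    }
  for (std::size_t m = 0; m < kM; ++m)
    for (std::size_t n = 0; n < kN; ++n)
      c[m][n] = static_cast<float>(m + n);
  assert(brgemmFmaChain(a, b, c) == brgemmReference(a, b, c));
}
```

Calling `runSelfCheck()` compares the two formulations on small integer inputs. In this sketch the accumulators stay live across the whole batch and K reduction, and each row of B is read once per step and shared by all M rows; whether the actual pass orders its loads and FMAs this way is not stated in the conversation.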