[mlir][Vector] Support mixed mode vector.contract lowering #117753
base: main
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Kunwar Grover (Groverkss)

Changes

This patch adds mixed-mode contract support. The implementation follows the documentation of vector.contract: https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop

> If operands and the result have types of different bitwidths, operands are promoted to have the same bitwidth as the result before performing the contraction. For integer types, only signless integer types are supported, and the promotion happens via sign extension.

Full diff: https://github.com/llvm/llvm-project/pull/117753.diff

3 files affected.
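As a rough illustration of the documented semantics quoted above, the following standalone C++ sketch models a matmul-style contraction with i8 operands and an i32 accumulator. It is not MLIR code and not the patch itself: `contractMixedI8ToI32` is a hypothetical name, and the `static_cast` from `int8_t` to `int32_t` stands in for `arith.extsi`.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical model of a matmul-style vector.contract whose operands are
// MxK and KxN i8 vectors and whose accumulator/result is an MxN i32 vector.
// Each operand element is first promoted to the result bitwidth by sign
// extension (the analogue of arith.extsi); only then are the values
// multiplied and accumulated in 32 bits.
template <std::size_t M, std::size_t N, std::size_t K>
std::array<std::array<std::int32_t, N>, M>
contractMixedI8ToI32(const std::array<std::array<std::int8_t, K>, M> &lhs,
                     const std::array<std::array<std::int8_t, N>, K> &rhs,
                     std::array<std::array<std::int32_t, N>, M> acc) {
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t k = 0; k < K; ++k) {
        // Promote i8 -> i32 by sign extension before the multiply.
        std::int32_t a = static_cast<std::int32_t>(lhs[m][k]);
        std::int32_t b = static_cast<std::int32_t>(rhs[k][n]);
        acc[m][n] += a * b;
      }
  return acc;
}
```

For example, with `lhs = {{-128, 127}, {-1, 2}}`, `rhs = {{-128, 3}, {127, -4}}` and a zero accumulator, element `[0][0]` is `(-128)(-128) + 127 * 127 = 32513`, a value that only fits because the operands were widened to 32 bits before the multiply.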
Branch force-pushed from cee8bec to 466676c.
```cpp
  if (isa<FloatType>(dstElementType))
    return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
  // For integer types, vector.contract only supports signless integer types
```
nit: I'd add a blank line before this comment.
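The snippet under review promotes floating-point operands with `arith.extf`. The sketch below is a hypothetical plain C++ model of that path for a dot-product contraction. It uses f32 operands and an f64 accumulator only because standard C++ has no portable f16 type; `contractMixedF32ToF64` is an illustrative name, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of a 1-D reduction vector.contract (a dot product) with
// f32 operands and an f64 accumulator/result. Each operand element is
// promoted to the result type first (the analogue of arith.extf), so both
// the multiply and the add happen in f64.
double contractMixedF32ToF64(const std::vector<float> &lhs,
                             const std::vector<float> &rhs, double acc) {
  assert(lhs.size() == rhs.size() && "operands must have the same shape");
  for (std::size_t k = 0; k < lhs.size(); ++k) {
    double a = static_cast<double>(lhs[k]); // f32 -> f64 promotion
    double b = static_cast<double>(rhs[k]); // f32 -> f64 promotion
    acc += a * b;
  }
  return acc;
}
```

With operands `{4097, 1}` and `{4097, 0.5}` this returns exactly `16785409.5`. The product `4097 * 4097 = 16785409` is odd and larger than 2^24, so it could not be represented exactly if it were rounded to f32 before being accumulated.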
dcaballe left a comment:
Hey, thanks a lot for the contribution. As you may know, this has been a controversial topic (actually, I'm not sure if we discussed it in the context of vector.contract or linalg at the time). The main problem is this:
> For integer types, only signless integer types are supported, and the promotion happens via sign extension.
If we want to walk this path, we should find a consistent way to represent signed and unsigned extensions per operand (and any potential flags associated with each operand extension, including FP ones). Otherwise, we would need different matching and handling depending on the specific extension each operand goes through.
Any ideas to address this issue?
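To make the per-operand concern concrete, here is a hypothetical C++ sketch in which each operand carries its own extension kind. `ExtKind`, `promote`, and `dotMixed` are illustrative names only: `vector.contract` has no per-operand signedness attribute, and its documentation fixes sign extension for all integer operands.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical per-operand extension kind. vector.contract has no such
// attribute; its documentation fixes sign extension for integer operands.
enum class ExtKind { Signed, Unsigned };

// Promote one signless 8-bit value to 32 bits according to `kind`.
// Signed models arith.extsi, Unsigned models arith.extui.
std::int32_t promote(std::uint8_t bits, ExtKind kind) {
  std::int32_t value = bits; // zero-extended: 0..255
  if (kind == ExtKind::Signed && value >= 128)
    value -= 256; // reinterpret the top bit as the sign bit
  return value;
}

// Dot-product contraction where each operand is extended according to its
// own kind before the multiply-accumulate in 32 bits.
std::int32_t dotMixed(const std::vector<std::uint8_t> &lhs, ExtKind lhsKind,
                      const std::vector<std::uint8_t> &rhs, ExtKind rhsKind,
                      std::int32_t acc) {
  assert(lhs.size() == rhs.size() && "operands must have the same length");
  for (std::size_t k = 0; k < lhs.size(); ++k)
    acc += promote(lhs[k], lhsKind) * promote(rhs[k], rhsKind);
  return acc;
}
```

The same bit pattern `0xFF` contributes -1 under sign extension and 255 under zero extension, so `dotMixed({0xFF}, ExtKind::Signed, {2}, ExtKind::Signed, 0)` yields -2, while switching the left operand to `ExtKind::Unsigned` yields 510. Any lowering or pattern that matches the contraction would therefore need to know which extension each operand uses.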
Just to be clear, this patch implements mixed-mode vector.contract lowering for the lowering variants other than outer product. The outer product lowering already implements mixed-mode lowering; this PR simply adds it to the other variants, in the same way the outer product lowering does it. If something here is controversial, I'm only following what is already implemented. On the point of signed/unsigned integer extensions per operand, the line you quoted about using sign extension is not something I decided; it comes from the vector.contract documentation (https://mlir.llvm.org/docs/Dialects/Vector/#vectorcontract-vectorcontractionop).
The documentation already specifies sign extension, which is what this PR implements. Regarding FP types, if an operand extension requires fast-math flags, I would guess that this should be annotated on vector.contract. But I'm not aware of any other lowerings that look at fast-math flags or any other FP flags. Do you have an example or a documentation link I could follow for the intended behavior there?
Sorry if my comment reads as critical; that was not the intent. This is a recurring issue, and I appreciate you bringing it up! It’s yet another half-baked part of the Vector dialect, and it’s probably a good time to address it. The main challenge lies in embedding conversion semantics for each operand within
What do you think?
OK, thanks for laying out the full picture. That makes sense; it does seem half-baked. I think option 2 might be too big a change to plumb through, since many transformations would need it. I know some mixed-precision fadd intrinsics are targeted through vector.contract, so that change might cause a lot of churn. How about we instead do option 1 in two parts:
What do you think?
It sounds great to me!
Cool! Let me send a patch tomorrow to update the vector.contract documentation as discussed.