[mlir][x86vector] Lower vector.contract to FMA or packed type dot-product #168074
✅ With the latest revision this PR passed the C/C++ code formatter.
This transform pass is intended to be used at the last stage of vectorization, after unrolling of
cc: @rengolin @shahidact please have a look.

cc: @rolfmorel
rengolin
left a comment
Do we have a short description somewhere of what shape the input is expected to be?
Also, potentially, what upstream transforms may produce it from Linalg.
    if (!lhsTy.getElementType().isF32())
      return rewriter.notifyMatchFailure(contractOp,
                                         "Only F32 lowering is supported.");
    if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
In theory, this would also work with B as the broadcasted element, no?
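The symmetry raised in this comment can be illustrated with a scalar model of the F32 case. This is only a sketch, not code from the PR: `fmaBroadcast` is a hypothetical helper, and it assumes the lowering amounts to broadcasting the unit-sized operand and fusing it into the accumulator with one FMA per lane.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical scalar model (not the PR's implementation): one operand has
// collapsed to a single element, which is broadcast across the lanes of the
// other operand and accumulated with a fused multiply-add:
//   acc[j] = scalar * row[j] + acc[j]
std::vector<float> fmaBroadcast(float scalar, const std::vector<float> &row,
                                std::vector<float> acc) {
  for (std::size_t j = 0; j < row.size() && j < acc.size(); ++j)
    acc[j] = std::fma(scalar, row[j], acc[j]);
  return acc;
}
```

Because the multiplication commutes, the helper does not care whether `scalar` came from A or from B; only the shape checks in the pattern would need to accept either side as the broadcast operand.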
    VectorType rhsTy = contractOp.getRhsType();
    ArrayRef<int64_t> rhsShape = rhsTy.getShape();
    llvm::SmallVector<int64_t> dimsRhs;
    llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
Couldn't this be count(rhsShape, 1) == rank(rhsShape) - 1?
Would avoid allocation and copies.
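A minimal sketch of the suggested check, written against `std::vector` so it compiles without LLVM headers. It assumes the intent of the `copy_if` is to require exactly one non-unit dimension; `hasSingleNonUnitDim` is a hypothetical name, not something from the PR.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when the number of unit dims equals rank - 1, i.e. exactly one
// dimension differs from 1. Nothing is allocated or copied.
// A rank-0 shape yields false because 0 != -1.
bool hasSingleNonUnitDim(const std::vector<int64_t> &shape) {
  std::ptrdiff_t unitDims =
      std::count(shape.begin(), shape.end(), int64_t{1});
  return unitDims == static_cast<std::ptrdiff_t>(shape.size()) - 1;
}
```

In the pattern itself the range-based `llvm::count(rhsShape, 1)` from `STLExtras.h` should serve the same purpose directly on the `ArrayRef`. Note that an all-ones shape fails this check, which may or may not match what the original `copy_if` logic accepts.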
    assert(accTy && "Invalid accumulator");
    ArrayRef<int64_t> accShape = accTy.getShape();
    llvm::SmallVector<int64_t> dimsAcc;
    llvm::copy_if(accShape, std::back_inserter(dimsAcc),
Shouldn't this be just accShape == rhsShape?
    return rewriter.notifyMatchFailure(
        contractOp, "The LHS vnni dim should be 4 for Int8.");
    llvm::SmallVector<int64_t> dimsLhs;
    llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
Same arguments as above.
A `transform` pass to lower `vector.contract` to (a) `vector.fma` for `F32`, (b) `x86vector.avx512.dot` for `BF16`, (c) `x86vector.avx.dot.i8` for `Int8` packed types.

The lowering applies only when the `m`, `batch`, and `k` dims are `1` and the `vnni` dim is `2` for `bf16` or `4` for `int8`.
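The shape conditions in this description can be sketched as a standalone predicate on the LHS shape. This is an illustration under stated assumptions, not the PR's code: `ElemKind`, `vnniFactor`, and `lhsShapeIsSupported` are hypothetical names, the `vnni` dim is assumed to be the innermost one, and `F32` is modelled as a packing factor of 1 so that an all-ones LHS shape is required, as in the reviewed F32 check.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

enum class ElemKind { F32, BF16, Int8 };

// Packing factor per element kind: 2 for bf16 and 4 for int8, as stated in
// the description. F32 is treated as unpacked (factor 1); that unification is
// an assumption made for this sketch.
int64_t vnniFactor(ElemKind kind) {
  switch (kind) {
  case ElemKind::BF16:
    return 2;
  case ElemKind::Int8:
    return 4;
  case ElemKind::F32:
    return 1;
  }
  return 1;
}

// Assumes the vnni dim is the innermost dim of the LHS shape. Every other dim
// (m, batch, k) must be 1, and the innermost dim must equal the factor.
bool lhsShapeIsSupported(const std::vector<int64_t> &shape, ElemKind kind) {
  if (shape.empty())
    return false;
  if (shape.back() != vnniFactor(kind))
    return false;
  return std::all_of(shape.begin(), shape.end() - 1,
                     [](int64_t dim) { return dim == 1; });
}
```

For example, under these assumptions `lhsShapeIsSupported({1, 1, 1, 2}, ElemKind::BF16)` holds, while the same shape with `ElemKind::Int8` is rejected because the innermost dim is not 4.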