@arun-thmn
Contributor

@arun-thmn arun-thmn commented Nov 14, 2025

A transform pass to lower vector.contract to (a) vector.fma for F32, (b) x86vector.avx512.dot for BF16, (c) x86vector.avx.dot.i8 for Int8 packed types.

The lowering applies only when the m, batch, and k dims are all one, and the vnni dim is 2 for BF16 or 4 for Int8.
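
As a rough illustration of these preconditions, here is a minimal sketch in plain C++. The names `ElemKind`, `ContractDims`, and `isLowerable` are hypothetical and do not come from the patch. The sketch also assumes that F32 has no vnni dim to check, which the description does not state explicitly.

```cpp
#include <cstdint>

// Hypothetical element-type tag; the actual pass inspects MLIR types.
enum class ElemKind { F32, BF16, I8 };

// Hypothetical summary of the dims named in the description.
// `vnni` is ignored for F32 (an assumption of this sketch).
struct ContractDims {
  int64_t batch;
  int64_t m;
  int64_t k;
  int64_t vnni;
};

// Returns true when the stated preconditions hold:
// batch, m and k are all one, and vnni is 2 for BF16 or 4 for Int8.
bool isLowerable(ElemKind kind, const ContractDims &d) {
  if (d.batch != 1 || d.m != 1 || d.k != 1)
    return false;
  switch (kind) {
  case ElemKind::F32:
    return true;
  case ElemKind::BF16:
    return d.vnni == 2;
  case ElemKind::I8:
    return d.vnni == 4;
  }
  return false;
}
```

With this sketch, a BF16 contraction with unit batch, m, and k dims and a vnni dim of 2 would qualify, while the same shape with a vnni dim of 4 would not.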

@github-actions

github-actions bot commented Nov 14, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@arun-thmn
Contributor Author

This transform pass is intended to be used at the last stage of vectorization, after vector.contract has been unrolled to unit dims.
This pass is based on @adam-smnk's work in TPP-MLIR PR:1063 (convert vc to fma). I borrowed a few of the test cases from that patch :). Adam will be added as co-author at the time of merge.

@arun-thmn
Contributor Author

cc: @rengolin @shahidact please have a look.

@arun-thmn
Contributor Author

cc: @rolfmorel

@rengolin rengolin requested a review from adam-smnk November 14, 2025 16:40
Member

@rengolin rengolin left a comment

Do we have a short description somewhere of what shape the input is expected to be?

Also, potentially, what upstream transforms may produce it from Linalg.

if (!lhsTy.getElementType().isF32())
  return rewriter.notifyMatchFailure(contractOp,
                                     "Only F32 lowering is supported.");
if (llvm::any_of(lhsTy.getShape(), [](int64_t dim) { return dim != 1; }))
Member

In theory, this would also work with B as the broadcasted element, no?

VectorType rhsTy = contractOp.getRhsType();
ArrayRef<int64_t> rhsShape = rhsTy.getShape();
llvm::SmallVector<int64_t> dimsRhs;
llvm::copy_if(rhsShape, std::back_inserter(dimsRhs),
Member

Couldn't this be count(rhsShape, 1) == rank(rhsShape) - 1?

Would avoid allocation and copies.
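
A minimal sketch of that suggestion, using `std::count` over a plain `std::vector<int64_t>` as a stand-in for `ArrayRef<int64_t>`. The helper name `hasSingleNonUnitDim` is hypothetical and not part of the patch, and the sketch assumes the intent is "every dim but one is a unit dim".

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: true when all dims except exactly one are equal to 1.
// Counting in place avoids building a filtered copy of the shape.
bool hasSingleNonUnitDim(const std::vector<int64_t> &shape) {
  if (shape.empty())
    return false;
  // Number of unit dims in the shape.
  std::ptrdiff_t unitDims =
      std::count(shape.begin(), shape.end(), int64_t{1});
  // Equivalent to count(shape, 1) == rank(shape) - 1.
  return unitDims == static_cast<std::ptrdiff_t>(shape.size()) - 1;
}
```

In-tree code would more likely use `llvm::count` directly on the `ArrayRef`; the standard-library version here only keeps the sketch self-contained.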

assert(accTy && "Invalid accumulator");
ArrayRef<int64_t> accShape = accTy.getShape();
llvm::SmallVector<int64_t> dimsAcc;
llvm::copy_if(accShape, std::back_inserter(dimsAcc),
Member

Shouldn't this be just accShape == rhsShape?

return rewriter.notifyMatchFailure(
    contractOp, "The LHS vnni dim should be 4 for Int8.");
llvm::SmallVector<int64_t> dimsLhs;
llvm::copy_if(lhsShape, std::back_inserter(dimsLhs),
Member

Same arguments as above.
