Skip to content

Commit 863e705

Browse files
authored
[LLVMGPU] Support masked contraction in operand upcasting (#19972)
Currently, the operands of a contraction are upcast in the LLVMGPUVectorLowering pass. This commit adds support for the case where the contraction is masked; in that case the upcasting should happen outside of the masking op. Signed-off-by: Manupa Karunaratne <[email protected]>
1 parent 32cfabf commit 863e705

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorLowering.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ namespace mlir::iree_compiler {
2626
namespace {
2727

2828
struct PromoteContractOperands final
29-
: public OpRewritePattern<vector::ContractionOp> {
30-
using OpRewritePattern::OpRewritePattern;
29+
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
30+
using MaskableOpRewritePattern::MaskableOpRewritePattern;
3131

32-
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
33-
PatternRewriter &rewriter) const override {
32+
FailureOr<Value>
33+
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
34+
vector::MaskingOpInterface maskOp,
35+
PatternRewriter &rewriter) const override {
3436
Type operandElType = getElementTypeOrSelf(contractOp.getLhsType());
3537
Type resultElType = getElementTypeOrSelf(contractOp.getResultType());
3638

@@ -44,11 +46,16 @@ struct PromoteContractOperands final
4446
Value rhs =
4547
promoteToElementType(loc, rewriter, contractOp.getRhs(), resultElType);
4648

47-
rewriter.replaceOpWithNewOp<vector::ContractionOp>(
48-
contractOp, lhs, rhs, contractOp.getAcc(), contractOp.getIndexingMaps(),
49+
auto replacement = rewriter.create<vector::ContractionOp>(
50+
loc, lhs, rhs, contractOp.getAcc(), contractOp.getIndexingMaps(),
4951
contractOp.getIteratorTypes());
5052

51-
return success();
53+
if (!maskOp) {
54+
return replacement.getResult();
55+
}
56+
auto maskedOp = vector::maskOperation(
57+
rewriter, replacement, maskOp.getMask(), maskOp.getPassthru());
58+
return maskedOp->getResult(0);
5259
}
5360

5461
Value promoteToElementType(Location loc, RewriterBase &rewriter, Value v,

compiler/src/iree/compiler/Codegen/LLVMGPU/test/vector_lowering.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,24 @@ module {
1616
// CHECK: %[[SPLAT:.+]] = vector.splat %[[ELEM]] : vector<8xf16>
1717
// CHECK: %[[INSERT:.+]] = vector.broadcast %[[SPLAT]] : vector<8xf16> to vector<1x8xf16>
1818
// CHECK: return %[[INSERT]]
19+
20+
// -----
21+
22+
module {
23+
func.func @contraction_masked(%lhs: vector<3xf16>, %rhs: vector<2x3xf16>, %acc: vector<2xf32>, %mask: vector<3x2xi1>) -> vector<2xf32> {
24+
%ret = vector.mask %mask { vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1)>], iterator_types = ["reduction", "parallel"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<3xf16>, vector<2x3xf16> into vector<2xf32> } : vector<3x2xi1> -> vector<2xf32>
25+
return %ret: vector<2xf32>
26+
}
27+
}
28+
29+
// CHECK-LABEL: func.func @contraction_masked
30+
// CHECK-SAME: %[[LHS:.+]]: vector<3xf16>, %[[RHS:.+]]: vector<2x3xf16>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<3x2xi1>
31+
// CHECK: %[[TPRHS:.+]] = vector.transpose %[[RHS]], [1, 0] : vector<2x3xf16> to vector<3x2xf16>
32+
// CHECK: %[[RHS_EXTRACT:.+]] = vector.extract %[[TPRHS]][0] : vector<2xf16> from vector<3x2xf16>
33+
// CHECK: %[[LHS_EXTRACT:.+]] = vector.extract %[[LHS]][0] : f16 from vector<3xf16>
34+
// CHECK: %[[RHS_CAST:.+]] = arith.extf %[[RHS_EXTRACT]] : vector<2xf16> to vector<2xf32>
35+
// CHECK: %[[LHS_CAST:.+]] = arith.extf %[[LHS_EXTRACT]] : f16 to f32
36+
// CHECK: %[[MASK_EXTRACT:.+]] = vector.extract %[[MASK]][0] : vector<2xi1> from vector<3x2xi1>
37+
// CHECK: %[[LHS_SPLAT:.+]] = vector.splat %[[LHS_CAST]] : vector<2xf32>
38+
// CHECK: %[[FMA:.+]] = vector.fma %[[RHS_CAST]], %[[LHS_SPLAT]], %[[ACC]] : vector<2xf32>
39+
// CHECK: arith.select %[[MASK_EXTRACT]], %[[FMA]], %[[ACC]] : vector<2xi1>, vector<2xf32>

0 commit comments

Comments
 (0)