Skip to content

Commit f017bcb

Browse files
authored
[mlir][gpu][spirv] Add conversion for gpu.subgroup_mma_elementwise mulf (#158832)
gpu.subgroup_mma_elementwise supports mulf op type. Add conversion for it.
1 parent beb6bab commit f017bcb

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ static bool createElementwiseOp(ConversionPatternRewriter &builder,
5555
case gpu::MMAElementwiseOp::SUBI:
5656
builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
5757
return true;
58+
case gpu::MMAElementwiseOp::MULF:
59+
builder.replaceOpWithNewOp<spirv::FMulOp>(op, coopType, operands);
60+
return true;
5861
case gpu::MMAElementwiseOp::DIVF:
5962
builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
6063
return true;

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,17 @@ module attributes {
136136
// CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
137137
%E = gpu.subgroup_mma_elementwise divf %D, %A :
138138
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
139+
// CHECK: {{%.*}} = spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
140+
%F = gpu.subgroup_mma_elementwise mulf %E, %A :
141+
(!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
139142
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} :
140143
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> to !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
141-
%F = gpu.subgroup_mma_elementwise extf %E :
144+
%G = gpu.subgroup_mma_elementwise extf %F :
142145
(!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
143146

144147
%i = arith.constant 0 : index
145148
// CHECK: spirv.KHR.CooperativeMatrixStore %{{.+}}, %{{.+}}, %{{.+}}, <RowMajor>
146-
gpu.subgroup_mma_store_matrix %F, %ptr[%i,%i] {leadDimension = 32 : index} :
149+
gpu.subgroup_mma_store_matrix %G, %ptr[%i,%i] {leadDimension = 32 : index} :
147150
!gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
148151
// CHECK: spirv.Return
149152
gpu.return

0 commit comments

Comments
 (0)