
Commit 46ad540

[mlir][gpu][vector] Lower Vector dialect to GPU for element-wise ops only (#159091)
Currently, convertVectorToMMAOps starts from vector.contract and collects its dependencies as the targets to convert. The GPU dialect already has a gpu.subgroup_mma_elementwise operation, so we should be able to lower element-wise operations to GPU MMA operations even when no vector.contract is present. This patch adds that case to the pattern.
Parent: 75469bb
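
For illustration, a minimal before/after sketch of the lowering this enables (it closely mirrors the @addf test added below, minus that test's transposed read; the exact output may differ after canonicalization).

Before, element-wise vector IR with no vector.contract in sight:

    %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
    %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
    %C = arith.addf %A, %B : vector<16x16xf16>
    vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>

After the VectorToGPU conversion, roughly:

    %A = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
    %B = gpu.subgroup_mma_load_matrix %arg1[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
    %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
    gpu.subgroup_mma_store_matrix %C, %arg2[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>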

2 files changed, 25 insertions(+), 3 deletions(-)

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 6 additions & 3 deletions
@@ -355,11 +355,14 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
   forwardSliceOptions.filter = hasVectorSrc;
 
   SetVector<Operation *> opToConvert;
-  op->walk([&](vector::ContractionOp contract) {
-    if (opToConvert.contains(contract.getOperation()))
+  op->walk([&](Operation *nestedOp) {
+    if (!isa<vector::ContractionOp>(nestedOp) &&
+        !elementwiseSupportsMMAMatrixType(nestedOp))
+      return;
+    if (opToConvert.contains(nestedOp))
       return;
     SetVector<Operation *> dependentOps =
-        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
+        getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
     // If any instruction cannot use MMA matrix type drop the whole
     // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
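
elementwiseSupportsMMAMatrixType is an existing helper in VectorToGPU.cpp that the new filter reuses: it answers whether an op has a gpu.subgroup_mma_elementwise counterpart. A rough sketch of that kind of check, for orientation only (the actual helper covers more ops and additional constraints):

    // Sketch, not the upstream implementation: accept ops that have a
    // gpu.subgroup_mma_elementwise equivalent (addf, mulf, ...).
    static bool elementwiseSupportsMMAMatrixType(Operation *op) {
      return isa<arith::AddFOp, arith::MulFOp, arith::SubFOp, arith::DivFOp>(op);
    }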

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 19 additions & 0 deletions
@@ -536,3 +536,22 @@ func.func @test_unsupported(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg
                 %0, %1, %arg2 : vector<4x4xi64>, vector<4x4xi64> into vector<4x4xi64>
   return %2 : vector<4x4xi64>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: func @addf
+// CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[C:.+]] = gpu.subgroup_mma_elementwise addf %[[A]], %[[B]] : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[C]]
+func.func @addf(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = arith.addf %A, %B : vector<16x16xf16>
+  vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
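
Because the new case is appended after a // ----- split marker, it runs under the test file's existing RUN line (outside this hunk), presumably something along the lines of:

    // RUN: mlir-opt %s -convert-vector-to-gpu -split-input-file | FileCheck %s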
