From 9b814e2bf72004ac5b829a6a18e4f30057e0d2e4 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Fri, 2 May 2025 14:32:50 +0100 Subject: [PATCH 1/4] [mlir][gpu] Add GPU subgroup MMA extract and insert operations - Introduced `gpu.subgroup_mma_extract` operation to extract values from `!gpu.mma_matrix` by invocation and indices. - Introduced `gpu.subgroup_mma_insert` operation to insert values into `!gpu.mma_matrix` by invocation and indices. - Updated the conversion patterns to SPIR-V for both extract and insert operations. - Added test cases to validate the new operations in the GPU to SPIR-V conversion. --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 73 +++++++++++++++++++ .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 63 ++++++++++++++++ .../wmma-ops-to-spirv-khr-coop-matrix.mlir | 27 +++++++ 3 files changed, 163 insertions(+) diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 68095b7bf5c59..cb363b501851b 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1919,6 +1919,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix", }]; } +def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract", + [Pure, + TypesMatchWith<"value type matches element type of mma_matrix", + "matrix", "res", + "::llvm::cast($_self).getElementType()">]>{ + + let summary = "Extract a value from GPU warp by invocation and indices"; + + let description = [{ + The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix` + by the invocation in a subgroup. + + This operation takes `!gpu.mma_matrix` as its first operand. It is the source + matrix across a subgroup. The op returns a scalar value stored in the invocation + in the subgroup. If there are multiple values packed in an invocation, use + `indices` to specify the element to extract. 
+ + Example: + + ```mlir + %c0 = arith.constant 0 : index + %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32 + ``` + }]; + + let arguments = (ins GPU_MMAMatrix:$matrix, Variadic:$indices); + + let results = (outs AnyIntegerOrFloat:$res); + + let assemblyFormat = [{ + $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res) + }]; +} + +def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert", + [Pure, + TypesMatchWith<"value type matches element type of mma_matrix", + "matrix", "value", + "::llvm::cast($_self).getElementType()"> ]>{ + + let summary = "Insert a value into GPU warp by invocation and indices"; + + let description = [{ + The `gpu.subgroup_mma_insert` operation inserts a value to `!gpu.mma_matrix` + by the invocation in a subgroup. + + This operation takes scalar value as its first operand and `!gpu.mma_matrix` + as its second operand. It is the matrix across a subgroup. The op inserts the + scalar value stored in the invocation in the subgroup to the matrix. If there + are multiple values packed in an invocation, use `indices` to specify the + location to insert in the packing. + + The op returns `!gpu.mma_matrix` with the updated value. 
+ + Example: + + ```mlir + %c0 = arith.constant 0 : index + %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> + -> !gpu.mma_matrix<16x16xf16, "COp"> + ``` + }]; + + let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix, + Variadic:$indices); + + let results = (outs GPU_MMAMatrix:$res); + + let assemblyFormat = [{ + $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res) + }]; +} + def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">; def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">; def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">; diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index df2da138d3b52..be76262f526d6 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final } }; +/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative +/// matrix ops. 
+struct WmmaExtractOpToSPIRVLowering final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value matrix = adaptor.getMatrix(); + auto coopType = + getTypeConverter()->convertType( + matrix.getType()); + if (!coopType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + SmallVector intValues; + for (Value val : op.getIndices()) { + if (auto constOp = val.getDefiningOp()) { + intValues.push_back(static_cast(constOp.value())); + } else { + return rewriter.notifyMatchFailure(op, "indices must be constants"); + } + } + + Type elementType = coopType.getElementType(); + rewriter.replaceOpWithNewOp( + op, elementType, matrix, rewriter.getI32ArrayAttr(intValues)); + return success(); + } +}; + +/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative +/// matrix ops. +struct WmmaInsertOpToSPIRVLowering final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value value = adaptor.getValue(); + Value matrix = adaptor.getMatrix(); + auto coopType = getTypeConverter()->convertType(matrix.getType()); + if (!coopType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + SmallVector intValues; + for (Value val : op.getIndices()) { + if (auto constOp = val.getDefiningOp()) { + intValues.push_back(static_cast(constOp.value())); + } else { + return rewriter.notifyMatchFailure(op, "indices must be constants"); + } + } + + rewriter.replaceOpWithNewOp( + op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues)); + return success(); + } +}; + /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for /// the default case. 
struct WmmaElementwiseOpToSPIRVDefaultLowering final @@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns( MLIRContext *context = patterns.getContext(); patterns.add(converter, context); // Give the following patterns higher benefit to prevail over the default one. patterns.add(converter, context, diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir index 477f344b1ae5f..3e8a3b21e7e94 100644 --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir @@ -93,6 +93,33 @@ module attributes { gpu.return } + // CHECK-LABEL: spirv.func @gpu_wmma_extract_op + // CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA> + gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">, + %ptr: memref<16x16xf32, #spirv.storage_class>) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA> + %c0 = arith.constant 0 : index + %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32 + memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class> + gpu.return + } + + // CHECK-LABEL: spirv.func @gpu_wmma_insert_op + // CHECK-SAME: %[[ARG0:.+]]: f16 + // CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + gpu.func @gpu_wmma_insert_op(%val: f16, + %m: !gpu.mma_matrix<16x16xf16, "COp">, + %ptr: memref<16x16xf16, #spirv.storage_class>) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> + %c0 = arith.constant 0 : index + %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> 
!gpu.mma_matrix<16x16xf16, "COp"> + gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} : + !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class> + gpu.return + } + // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> // CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> From b6910549d818f8c83afad22304635972b0d7dec7 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Tue, 20 May 2025 11:06:11 +0100 Subject: [PATCH 2/4] Add a parsing/printing test --- mlir/test/Dialect/GPU/ops.mlir | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 99915c493ea46..0364fc47b9308 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -430,6 +430,20 @@ module attributes {gpu.container_module} { gpu.wait [%token16] return } + + // CHECK-LABEL: func @extract_insert_mma + func.func @extract_insert_mma(%src : !gpu.mma_matrix<16x16xf32, "COp">, + %ptr: memref<16x16xf32>) { + %zero = arith.constant 0.0 : f32 + %c0 = arith.constant 0 : index + // CHECK: gpu.subgroup_mma_extract + %val = gpu.subgroup_mma_extract %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32 + %m = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf32, "COp"> + // CHECK: gpu.subgroup_mma_insert + %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp"> + gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32> + return + } } // Just check that this doesn't crash. 
From 96fcae522342fcd1dd43be631479d90e5c1528e3 Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Thu, 22 May 2025 10:28:02 +0100 Subject: [PATCH 3/4] add more description about how mma matrix is stored --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index cb363b501851b..4c107a487aa4d 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1933,8 +1933,9 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract", This operation takes `!gpu.mma_matrix` as its first operand. It is the source matrix across a subgroup. The op returns a scalar value stored in the invocation - in the subgroup. If there are multiple values packed in an invocation, use - `indices` to specify the element to extract. + in the subgroup. The values of !gpu.mma_matrix are stored across multiple + threads in the subgroup. If there are multiple values packed in a thread, use + `indices` to specify the element in the local thread to extract. Example: @@ -1967,7 +1968,8 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert", This operation takes scalar value as its first operand and `!gpu.mma_matrix` as its second operand. It is the matrix across a subgroup. The op inserts the - scalar value stored in the invocation in the subgroup to the matrix. If there + scalar value stored in the invocation in the subgroup to the matrix. The values + of !gpu.mma_matrix are stored across multiple threads in the subgroup. If there are multiple values packed in an invocation, use `indices` to specify the location to insert in the packing. 
From c7269d3948b604075927f0bc87e57d9bfff707cf Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Fri, 23 May 2025 07:24:43 +0100 Subject: [PATCH 4/4] refine descriptions and rename operations --- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 46 ++++++++++++------- .../Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp | 8 ++-- .../wmma-ops-to-spirv-khr-coop-matrix.mlir | 12 ++--- mlir/test/Dialect/GPU/ops.mlir | 8 ++-- 4 files changed, 44 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index 4c107a487aa4d..fb27630ed3b48 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1919,7 +1919,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix", }]; } -def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract", +def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local", [Pure, TypesMatchWith<"value type matches element type of mma_matrix", "matrix", "res", @@ -1928,20 +1928,28 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract", let summary = "Extract a value from GPU warp by invocation and indices"; let description = [{ - The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix` - by the invocation in a subgroup. + The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from `!gpu.mma_matrix` + that is stored at subgroup level. This operation takes `!gpu.mma_matrix` as its first operand. It is the source matrix across a subgroup. The op returns a scalar value stored in the invocation - in the subgroup. The values of !gpu.mma_matrix are stored across multiple - threads in the subgroup. If there are multiple values packed in a thread, use - `indices` to specify the element in the local thread to extract. + in the subgroup. 
+ + Since `matrix` is packed into the threads within a subgroup, `indices` are + the indices into the values stored by each thread. That is, an index of 0 (or [0, 0]) + does not necessarily refer to the first element of the matrix, but the first element + that a particular thread holds. + + The mapping of matrix elements to threads is not defined by this operation and may + not be defined by some lowerings (such as the lowering to SPIR-V). However, if the + size of the subgroup is S, then `subgroup_mma_extract_thread_local` at each index in + `[0, (M * N) / S)` will have the entire matrix extracted across the subgroup. Example: ```mlir %c0 = arith.constant 0 : index - %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32 + %val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32 ``` }]; @@ -1954,7 +1962,7 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract", }]; } -def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert", +def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local", [Pure, TypesMatchWith<"value type matches element type of mma_matrix", "matrix", "value", @@ -1963,15 +1971,21 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert", let summary = "Insert a value into GPU warp by invocation and indices"; let description = [{ - The `gpu.subgroup_mma_insert` operation inserts a value to `!gpu.mma_matrix` - by the invocation in a subgroup. + The `gpu.subgroup_mma_insert_thread_local` operation inserts a value to `!gpu.mma_matrix` + that is stored at subgroup level. This operation takes scalar value as its first operand and `!gpu.mma_matrix` - as its second operand. It is the matrix across a subgroup. The op inserts the - scalar value stored in the invocation in the subgroup to the matrix. The values - of !gpu.mma_matrix are stored across multiple threads in the subgroup. 
If there - are multiple values packed in an invocation, use `indices` to specify the - location to insert in the packing. + as its second operand. The op inserts the scalar value to the matrix. + + Since `matrix` is packed into the threads within a subgroup, `indices` are + the indices into the values stored by each thread. That is, an index of 0 (or [0, 0]) + does not necessarily refer to the first element of the matrix, but the first element + that a particular thread holds. + + The mapping of matrix elements to threads is not defined by this operation and may + not be defined by some lowerings (such as the lowering to SPIR-V). However, if the + size of the subgroup is S, then `subgroup_mma_insert_thread_local` at each index in + `[0, (M * N) / S)` will have the entire matrix inserted across the subgroup. The op returns `!gpu.mma_matrix` with the updated value. @@ -1979,7 +1993,7 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert", ```mlir %c0 = arith.constant 0 : index - %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> + %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp"> ``` }]; diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp index be76262f526d6..d2f5e35853550 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp @@ -114,11 +114,11 @@ struct WmmaConstantOpToSPIRVLowering final /// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative /// matrix ops. 
struct WmmaExtractOpToSPIRVLowering final - : OpConversionPattern { + : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor, + matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value matrix = adaptor.getMatrix(); auto coopType = @@ -146,11 +146,11 @@ struct WmmaExtractOpToSPIRVLowering final /// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative /// matrix ops. struct WmmaInsertOpToSPIRVLowering final - : OpConversionPattern { + : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor, + matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value value = adaptor.getValue(); Value matrix = adaptor.getMatrix(); diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir index 3e8a3b21e7e94..7ef3711ebe28b 100644 --- a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir @@ -93,28 +93,28 @@ module attributes { gpu.return } - // CHECK-LABEL: spirv.func @gpu_wmma_extract_op + // CHECK-LABEL: spirv.func @gpu_wmma_extract_thread_local_op // CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA> - gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">, + gpu.func @gpu_wmma_extract_thread_local_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">, %ptr: memref<16x16xf32, #spirv.storage_class>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA> %c0 = arith.constant 0 : index - %val = 
gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32 + %val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32 memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class> gpu.return } - // CHECK-LABEL: spirv.func @gpu_wmma_insert_op + // CHECK-LABEL: spirv.func @gpu_wmma_insert_thread_local_op // CHECK-SAME: %[[ARG0:.+]]: f16 // CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> - gpu.func @gpu_wmma_insert_op(%val: f16, + gpu.func @gpu_wmma_insert_thread_local_op(%val: f16, %m: !gpu.mma_matrix<16x16xf16, "COp">, %ptr: memref<16x16xf16, #spirv.storage_class>) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc> %c0 = arith.constant 0 : index - %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp"> + %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp"> gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class> gpu.return diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 0364fc47b9308..9dbe16774f517 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -436,11 +436,11 @@ module attributes {gpu.container_module} { %ptr: memref<16x16xf32>) { %zero = arith.constant 0.0 : f32 %c0 = arith.constant 0 : index - // CHECK: gpu.subgroup_mma_extract - %val = gpu.subgroup_mma_extract %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32 + // CHECK: gpu.subgroup_mma_extract_thread_local + %val = gpu.subgroup_mma_extract_thread_local %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32 %m = gpu.subgroup_mma_constant_matrix %zero : 
!gpu.mma_matrix<16x16xf32, "COp"> - // CHECK: gpu.subgroup_mma_insert - %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp"> + // CHECK: gpu.subgroup_mma_insert_thread_local + %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp"> gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32> return }