Skip to content

Commit c7269d3

Browse files
committed
refine descriptions and rename operations
1 parent 96fcae5 commit c7269d3

File tree

4 files changed

+44
-30
lines changed

4 files changed

+44
-30
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
19191919
}];
19201920
}
19211921

1922-
def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
1922+
def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local",
19231923
[Pure,
19241924
TypesMatchWith<"value type matches element type of mma_matrix",
19251925
"matrix", "res",
@@ -1928,20 +1928,28 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
19281928
let summary = "Extract a value from GPU warp by invocation and indices";
19291929

19301930
let description = [{
1931-
The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix`
1932-
by the invocation in a subgroup.
1931+
The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from `!gpu.mma_matrix`
1932+
that is stored at subgroup level.
19331933

19341934
This operation takes `!gpu.mma_matrix` as its first operand. It is the source
19351935
matrix across a subgroup. The op returns a scalar value stored in the invocation
1936-
in the subgroup. The values of !gpu.mma_matrix are stored across multiple
1937-
threads in the subgroup. If there are multiple values packed in a thread, use
1938-
`indices` to specify the element in the local thread to extract.
1936+
in the subgroup.
1937+
1938+
Since `matrix` is packed into the threads within a subgroup, `indices` are
1939+
the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1940+
does not necessarily refer to the first element of the matrix, but the first element
1941+
that a particular thread holds.
1942+
1943+
The mapping of matrix elements to threads is not defined by this operation and may
1944+
not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
1945+
size of the subgroup is S, then `subgroup_mma_extract_thread_local` at each index in
1946+
`[0, (M * N) / S)` will have the entire matrix extracted across the subgroup.
19391947

19401948
Example:
19411949

19421950
```mlir
19431951
%c0 = arith.constant 0 : index
1944-
%val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
1952+
%val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
19451953
```
19461954
}];
19471955

@@ -1954,7 +1962,7 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
19541962
}];
19551963
}
19561964

1957-
def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
1965+
def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local",
19581966
[Pure,
19591967
TypesMatchWith<"value type matches element type of mma_matrix",
19601968
"matrix", "value",
@@ -1963,23 +1971,29 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
19631971
let summary = "Insert a value into GPU warp by invocation and indices";
19641972

19651973
let description = [{
1966-
The `gpu.subgroup_mma_insert` operation inserts a value to `!gpu.mma_matrix`
1967-
by the invocation in a subgroup.
1974+
The `gpu.subgroup_mma_insert_thread_local` operation inserts a value to `!gpu.mma_matrix`
1975+
that is stored at subgroup level.
19681976

19691977
This operation takes a scalar value as its first operand and `!gpu.mma_matrix`
1970-
as its second operand. It is the matrix across a subgroup. The op inserts the
1971-
scalar value stored in the invocation in the subgroup to the matrix. The values
1972-
of !gpu.mma_matrix are stored across multiple threads in the subgroup. If there
1973-
are multiple values packed in an invocation, use `indices` to specify the
1974-
location to insert in the packing.
1978+
as its second operand. The op inserts the scalar value into the matrix.
1979+
1980+
Since `matrix` is packed into the threads within a subgroup, `indices` are
1981+
the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1982+
does not necessarily refer to the first element of the matrix, but the first element
1983+
that a particular thread holds.
1984+
1985+
The mapping of matrix elements to threads is not defined by this operation and may
1986+
not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
1987+
size of the subgroup is S, then `subgroup_mma_insert_thread_local` at each index in
1988+
`[0, (M * N) / S)` will have the entire matrix inserted across the subgroup.
19751989

19761990
The op returns `!gpu.mma_matrix` with the updated value.
19771991

19781992
Example:
19791993

19801994
```mlir
19811995
%c0 = arith.constant 0 : index
1982-
%s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
1996+
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
19831997
-> !gpu.mma_matrix<16x16xf16, "COp">
19841998
```
19851999
}];

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ struct WmmaConstantOpToSPIRVLowering final
114114
/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
115115
/// matrix ops.
116116
struct WmmaExtractOpToSPIRVLowering final
117-
: OpConversionPattern<gpu::SubgroupMmaExtractOp> {
117+
: OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
118118
using OpConversionPattern::OpConversionPattern;
119119

120120
LogicalResult
121-
matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
121+
matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122122
ConversionPatternRewriter &rewriter) const override {
123123
Value matrix = adaptor.getMatrix();
124124
auto coopType =
@@ -146,11 +146,11 @@ struct WmmaExtractOpToSPIRVLowering final
146146
/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
147147
/// matrix ops.
148148
struct WmmaInsertOpToSPIRVLowering final
149-
: OpConversionPattern<gpu::SubgroupMmaInsertOp> {
149+
: OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
150150
using OpConversionPattern::OpConversionPattern;
151151

152152
LogicalResult
153-
matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
153+
matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154154
ConversionPatternRewriter &rewriter) const override {
155155
Value value = adaptor.getValue();
156156
Value matrix = adaptor.getMatrix();

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,28 +93,28 @@ module attributes {
9393
gpu.return
9494
}
9595

96-
// CHECK-LABEL: spirv.func @gpu_wmma_extract_op
96+
// CHECK-LABEL: spirv.func @gpu_wmma_extract_thread_local_op
9797
// CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
98-
gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
98+
gpu.func @gpu_wmma_extract_thread_local_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
9999
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
100100
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
101101
// CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
102102
%c0 = arith.constant 0 : index
103-
%val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
103+
%val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
104104
memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
105105
gpu.return
106106
}
107107

108-
// CHECK-LABEL: spirv.func @gpu_wmma_insert_op
108+
// CHECK-LABEL: spirv.func @gpu_wmma_insert_thread_local_op
109109
// CHECK-SAME: %[[ARG0:.+]]: f16
110110
// CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
111-
gpu.func @gpu_wmma_insert_op(%val: f16,
111+
gpu.func @gpu_wmma_insert_thread_local_op(%val: f16,
112112
%m: !gpu.mma_matrix<16x16xf16, "COp">,
113113
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
114114
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
115115
// CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
116116
%c0 = arith.constant 0 : index
117-
%s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
117+
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
118118
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
119119
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
120120
gpu.return

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,11 +436,11 @@ module attributes {gpu.container_module} {
436436
%ptr: memref<16x16xf32>) {
437437
%zero = arith.constant 0.0 : f32
438438
%c0 = arith.constant 0 : index
439-
// CHECK: gpu.subgroup_mma_extract
440-
%val = gpu.subgroup_mma_extract %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
439+
// CHECK: gpu.subgroup_mma_extract_thread_local
440+
%val = gpu.subgroup_mma_extract_thread_local %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
441441
%m = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf32, "COp">
442-
// CHECK: gpu.subgroup_mma_insert
443-
%s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
442+
// CHECK: gpu.subgroup_mma_insert_thread_local
443+
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
444444
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
445445
return
446446
}

0 commit comments

Comments
 (0)