@@ -1919,7 +1919,7 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
19191919 }];
19201920}
19211921
1922- def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract ",
1922+ def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local ",
19231923 [Pure,
19241924 TypesMatchWith<"value type matches element type of mma_matrix",
19251925 "matrix", "res",
@@ -1928,20 +1928,28 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
19281928 let summary = "Extract a value from GPU warp by invocation and indices";
19291929
19301930 let description = [{
1931- The `gpu.subgroup_mma_extract ` operation extracts a value from `!gpu.mma_matrix`
1932- by the invocation in a subgroup.
1931+ The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from `!gpu.mma_matrix`
1932+ that is stored at subgroup level.
19331933
19341934 This operation takes `!gpu.mma_matrix` as its first operand. It is the source
19351935 matrix across a subgroup. The op returns a scalar value stored in the invocation
1936- in the subgroup. The values of !gpu.mma_matrix are stored across multiple
1937- threads in the subgroup. If there are multiple values packed in a thread, use
1938- `indices` to specify the element in the local thread to extract.
1936+ in the subgroup.
1937+
1938+ Since `matrix` is packed into the threads within a subgroup, `indices` are
1939+ the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1940+ does not necessarily refer to the first element of the matrix, but the first element
1941+ that a particular thread holds.
1942+
1943+ The mapping of matrix elements to threads is not defined by this operation and may
1944+ not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
1945+ matrix has shape M x N and the subgroup size is S, then `subgroup_mma_extract_thread_local`
1946+ at each index in `[0, (M * N) / S)` will have the entire matrix extracted across the subgroup.
19391947
19401948 Example:
19411949
19421950 ```mlir
19431951 %c0 = arith.constant 0 : index
1944- %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
1952+ %val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
19451953 ```
19461954 }];
19471955
@@ -1954,7 +1962,7 @@ def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
19541962 }];
19551963}
19561964
1957- def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert ",
1965+ def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local ",
19581966 [Pure,
19591967 TypesMatchWith<"value type matches element type of mma_matrix",
19601968 "matrix", "value",
@@ -1963,23 +1971,29 @@ def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
19631971 let summary = "Insert a value into GPU warp by invocation and indices";
19641972
19651973 let description = [{
1966- The `gpu.subgroup_mma_insert ` operation inserts a value to `!gpu.mma_matrix`
1967- by the invocation in a subgroup.
1974+ The `gpu.subgroup_mma_insert_thread_local` operation inserts a value into `!gpu.mma_matrix`
1975+ that is stored at subgroup level.
19681976
19691977 This operation takes a scalar value as its first operand and `!gpu.mma_matrix`
1970- as its second operand. It is the matrix across a subgroup. The op inserts the
1971- scalar value stored in the invocation in the subgroup to the matrix. The values
1972- of !gpu.mma_matrix are stored across multiple threads in the subgroup. If there
1973- are multiple values packed in an invocation, use `indices` to specify the
1974- location to insert in the packing.
1978+ as its second operand. The op inserts the scalar value into the matrix.
1979+
1980+ Since `matrix` is packed into the threads within a subgroup, `indices` are
1981+ the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1982+ does not necessarily refer to the first element of the matrix, but the first element
1983+ that a particular thread holds.
1984+
1985+ The mapping of matrix elements to threads is not defined by this operation and may
1986+ not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
1987+ matrix has shape M x N and the subgroup size is S, then `subgroup_mma_insert_thread_local`
1988+ at each index in `[0, (M * N) / S)` will have the entire matrix inserted across the subgroup.
19751989
19761990 The op returns `!gpu.mma_matrix` with the updated value.
19771991
19781992 Example:
19791993
19801994 ```mlir
19811995 %c0 = arith.constant 0 : index
1982- %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
1996+ %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
19831997 -> !gpu.mma_matrix<16x16xf16, "COp">
19841998 ```
19851999 }];
0 commit comments