@@ -1921,6 +1921,95 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
19211921 }];
19221922}
19231923
1924+ def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local",
1925+ [Pure,
1926+ TypesMatchWith<"value type matches element type of mma_matrix",
1927+ "matrix", "res",
1928+ "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
1929+
1930+ let summary = "Extract a value from GPU warp by invocation and indices";
1931+
1932+ let description = [{
1933+ The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from a
1934+ `!gpu.mma_matrix` that is stored at subgroup level.
1935+
1936+ This operation takes a `!gpu.mma_matrix` as its first operand. It is the source
1937+ matrix across a subgroup. The op returns a scalar value stored in the invocation
1938+ in the subgroup.
1939+
1940+ Since `matrix` is packed into the threads within a subgroup, `indices` are
1941+ the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1942+ does not necessarily refer to the first element of the matrix, but the first element
1943+ that a particular thread holds.
1944+
1945+ The mapping of matrix elements to threads is not defined by this operation and may
1946+ not be defined by some lowerings (such as the lowering to SPIR-V). However, for an M x N
1947+ matrix and a subgroup of size S, `subgroup_mma_extract_thread_local` at each index in
1948+ `[0, (M * N) / S)` will have the entire matrix extracted across the subgroup.
1949+
1950+ Example:
1951+
1952+ ```mlir
1953+ %c0 = arith.constant 0 : index
1954+ %val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
1955+ ```
1956+ }];
1957+
1958+ let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
1959+
1960+ let results = (outs AnyIntegerOrFloat:$res);
1961+
1962+ let assemblyFormat = [{
1963+ $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
1964+ }];
1965+ }
1966+
1967+ def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local",
1968+ [Pure, AllTypesMatch<["matrix", "res"]>, // the result keeps the type of the input matrix
1969+ TypesMatchWith<"value type matches element type of mma_matrix",
1970+ "matrix", "value",
1971+ "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{
1972+
1973+ let summary = "Insert a value into GPU warp by invocation and indices";
1974+
1975+ let description = [{
1976+ The `gpu.subgroup_mma_insert_thread_local` operation inserts a value into a
1977+ `!gpu.mma_matrix` that is stored at subgroup level.
1978+
1979+ This operation takes a scalar value as its first operand and a `!gpu.mma_matrix`
1980+ as its second operand. The op inserts the scalar value into the matrix.
1981+
1982+ Since `matrix` is packed into the threads within a subgroup, `indices` are
1983+ the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1984+ does not necessarily refer to the first element of the matrix, but the first element
1985+ that a particular thread holds.
1986+
1987+ The mapping of matrix elements to threads is not defined by this operation and may
1988+ not be defined by some lowerings (such as the lowering to SPIR-V). However, for an M x N
1989+ matrix and a subgroup of size S, `subgroup_mma_insert_thread_local` at each index in
1990+ `[0, (M * N) / S)` will have the entire matrix inserted across the subgroup.
1991+
1992+ The op returns a `!gpu.mma_matrix` of the same type as `matrix` with the updated value.
1993+
1994+ Example:
1995+
1996+ ```mlir
1997+ %c0 = arith.constant 0 : index
1998+ %s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
1999+ -> !gpu.mma_matrix<16x16xf16, "COp">
2000+ ```
2001+ }];
2002+
2003+ let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
2004+ Variadic<Index>:$indices);
2005+
2006+ let results = (outs GPU_MMAMatrix:$res);
2007+
2008+ let assemblyFormat = [{
2009+ $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
2010+ }];
2011+ }
2012+
19242013def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">; // enum case "ADDF", value 0, spelled "addf"
19252014def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">; // enum case "MULF", value 1, spelled "mulf"
19262015def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">; // enum case "SUBF", value 2, spelled "subf". NOTE(review): def name uses "SUBF" casing, unlike "AddF"/"MulF" above
0 commit comments