[MLIR][NVVM] Support for dense and sparse MMA with block scaling #170566
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm

Author: Kirill Vedernikov (kvederni)

Changes: This change adds dense and sparse MMA with block scaling intrinsics to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementation is based on PTX ISA 9.0.

Patch is 121.75 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/170566.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a96d65d3fcacd..1faa435fca6f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);
+ // Block scale MMA operations (dense)
+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+ [GEOM<16,8,32>],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["f32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+ mxf4_mma_ops, mxf8f6f4_mma_ops);
+
+ // Block scale sparse MMA operations
+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,128>],
+ ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+ [GEOM<16,8,64>],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
The optional `orderedMetadata` attribute specifies the metadata ordering:
- Absence (default): Uses standard sparse metadata ordering
- Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
-
+
The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
- `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
- Only valid with ordered metadata and m16n8k64 shape
@@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
sparseMetadata[%meta] selector[%sel]
{shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
: (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-
+
// With ordered metadata:
%d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
sparseMetadata[%meta] selector[%sel]
@@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
let hasVerifier = 1;
}
+def ScaleVecSize1X : I32EnumAttrCase<"X1", 0, "x1">;
+def ScaleVecSize2X : I32EnumAttrCase<"X2", 1, "x2">;
+def ScaleVecSize4X : I32EnumAttrCase<"X4", 2, "x4">;
+
+def ScaleVecSize : I32EnumAttr<
+ "ScaleVecSize",
+ "MMA Scale Vector Sizes",
+ [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def ScaleVecSizeAttr : EnumAttr<NVVM_Dialect, ScaleVecSize, "scale_vec_size"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">;
+def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">;
+
+def BlockScaleFormat : I32EnumAttr<
+ "BlockScaleFormat",
+ "MMA Block Scale Format",
+ [UE8M0, UE4M3]
+> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def BlockScaleFormatAttr : EnumAttr<NVVM_Dialect, BlockScaleFormat, "block_scale_format"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def MMABlockScaleKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">;
+def MMABlockScaleKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">;
+def MMABlockScaleKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">;
+
+def MMABlockScaleKind : I32EnumAttr<
+ "MMABlockScaleKind",
+ "Block Scale Kind",
+ [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+
+def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_scale_kind"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+/// Generate enum value of the mma.block_scale intrinsic.
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
+/// Generate enum value of the mma.sp.block_scale intrinsic.
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
+// Returns true if this combination is supported for MMA.BLOCK_SCALE ops.
+// This references the NVVM_MMA_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x")),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_2x"),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_4x"),
+ !eq(stype, "ue4m3")) : true,
+ !and(!eq(geom, "m16n8k32"),
+ !eq(kind, "mxf8f6f4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x")),
+ !eq(stype, "ue8m0")) : true,
+ true: false
+ );
+}
+
+// Returns true if this combination is supported for MMA.SP.BLOCK_SCALE ops.
+// This references the NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x"))): true,
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale_vec_size, ".scale_2x")): true,
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue4m3"),
+ !eq(scale_vec_size, ".scale_4x")): true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf8f6f4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x"))): true,
+ true: false
+ );
+}
+
+/// Helper to create the mapping between the configuration and the mma.block_scale
+/// intrinsic enum value.
+class MMA_BLOCK_SCALE_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops,
+ !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+ !foreach(stype, ["ue8m0", "ue4m3"],
+ !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # kind # "\" == stringifyEnum(kind)"
+ # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+ # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+ # " return " #
+ MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // stype
+ ) // scale_vec_size
+ ) // kind
+ ); // all_mma_block_scale_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.block_scale
+/// intrinsic enum value.
+class MMA_SP_BLOCK_SCALE_INTR {
+ list<list<list<list<string>>>> cond0 =
+ !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops,
+ !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+ !foreach(stype, ["ue8m0", "ue4m3"],
+ !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+ "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+ # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+ # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+ # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+ # " && \"" # kind # "\" == stringifyEnum(kind)"
+ # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+ # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+ # " return " #
+ MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+ "") // if supported
+ ) // stype
+ ) // scale_vec_size
+ ) // kind
+ ); // all_mma_sp_block_scale_ops
+ list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+ !listconcat(acc, el));
+ list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+ list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+ string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+// Common base class for MMA block scale operations (dense and sparse)
+class NVVM_MmaBlockScaleBase<string mnemonic, list<Trait> traits = []> :
+ NVVM_Op<mnemonic, !listconcat([AttrSizedOperandSegments], traits)> {
+
+ let results = (outs LLVM_AnyStruct:$res);
+
+ // Common attributes shared by both dense and sparse variants
+ dag commonArguments = (ins
+ NVVM_MMAShapeAttr:$shape,
+ OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+ OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+ ScaleVecSizeAttr:$scaleVecSize,
+ BlockScaleFormatAttr:$blockScaleFormat,
+ MMABlockScaleKindAttr:$kind);
+
+ // Common variadic operands for A, B, C matrices
+ dag commonVariadicOperands = (ins
+ Variadic<LLVM_Type>:$operandA,
+ Variadic<LLVM_Type>:$operandB,
+ Variadic<LLVM_Type>:$operandC);
+
+ // Common scale operands for both A and B
+ dag commonScaleOperands = (ins
+ I32:$scaleAData,
+ I16:$byteIdA,
+ I16:$threadIdA,
+ I32:$scaleBData,
+ I16:$byteIdB,
+ I16:$threadIdB);
+
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum,
+ mlir::NVVM::ScaleVecSize scaleVecSize,
+ mlir::NVVM::BlockScaleFormat blockScaleFormat,
+ mlir::NVVM::MMABlockScaleKind kind) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+ auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+ switch (svs) {
+ case ScaleVecSize::X1: return ".scale_1x";
+ case ScaleVecSize::X2: return ".scale_2x";
+ case ScaleVecSize::X4: return ".scale_4x";
+ }
+ return "";
+ };
+ }],
+ MMA_BLOCK_SCALE_INTR<>.id, [{
+ return 0;
+ }
+
+ // Common declarations - implementations in NVVMDialect.cpp
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let hasCustomAssemblyFormat = 1;
+ let hasVerifier = 1;
+}
+
+def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> {
+
+ let summary = "cooperative matrix-multiply and accumulate with block scaling";
+
+ let description = [{
+ The `nvvm.mma.block_scale` operation collectively performs the operation
+ `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp.
+
+ A, B, C and D are dense matrices and SF_A and SF_B are scaling factors.
+ Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+ and the data type must be either ue8m0 or ue4m3.
+
+ All the threads in the warp must execute the same `mma.block_scale` operation.
+
+ This operation follows the same design pattern as `nvvm.mma.sync`, with additional
+ scaling operands for both A and B matrices.
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4nvf4>}
+ : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+ ```
+ }];
+
+ // Combine common attributes and operands
+ let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+ "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "ScaleVecSize":$scaleVecSize,
+ "BlockScaleFormat":$blockScaleFormat,
+ "MMABlockScaleKind":$kind)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
+ }];
+}
+
+def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> {
+
+ let summary = "cooperative sparse matrix-multiply and accumulate with block scaling";
+
+ let description = [{
+ The `nvvm.mma.sp.block_scale` operation collectively performs the operation
+ `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp.
+
+ A is a sparse matrix, and B, C and D are dense matrices.
+ SF_A and SF_B are scaling factors.
+ Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+ and the data type must be either ue8m0 or ue4m3.
+
+ This operation is similar to `nvvm.mma.block_scale` but with structured sparsity
+ in the A operand. The sparsity follows the 2:4 structured sparse pattern
+ where 2 out of every 4 elements are non-zero.
+
+ All the threads in the warp must execute the same `mma.sp.block_scale` operation.
+
+ The `sparseMetadata` operand provides the sparsity indices that indicate
+ which elements in the A operand are non-zero. The `sparsitySelector`
+ controls how the indices are distributed among threads in the warp and
+ should typically be 0 or 1.
+
+ This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional
+ scaling operands for both A and B matrices. Note that sparse block scale operations
+ always use ordered metadata (sm_90+).
+
+ Example:
+ ```mlir
+ %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+ sparseMetadata[%meta] selector[%sel]
+ scaleA[%scaleAData, %byteIdA, %threadIdA]
+ scaleB[%scaleBData, %byteIdB, %threadIdB]
+ {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+ multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+ multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+ scaleVecSize = #nvvm.scale_vec_size<x2>,
+ blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+ kind = #nvvm.block_scale_kind<mxf4>}
+ : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+ ```
+ }];
+
+ // Sparse-specific attributes and operands
+ dag sparseSpecificArguments = (ins
+ UnitAttr:$orderedMetadata);
+
+ dag sparseSpecificOperands = (ins
+ I32:$sparseMetadata,
+ I32:$sparsitySelector);
+
+ // Combine common and sparse-specific attributes and operands
+ let arguments = !con(commonArguments, sparseSpecificArguments,
+ commonVariadicOperands, sparseSpecificOperands,
+ commonScaleOperands);
+
+ // Override extraClassDeclaration to use sparse intrinsics
+ let extraClassDeclaration = !strconcat([{
+ static llvm::Intrinsic::ID getIntrinsicID(
+ int64_t m, int64_t n, uint64_t k,
+ mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+ mlir::NVVM::MMATypes eltypeCEnum,
+ mlir::NVVM::ScaleVecSize scaleVecSize,
+ mlir::NVVM::BlockScaleFormat blockScaleFormat,
+ mlir::NVVM::MMABlockScaleKind kind) {
+ llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+ llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+ llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+ auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+ switch (svs) {
+ case ScaleVecSize::X1: return ".scale_1x";
+ case ScaleVecSize::X2: return ".scale_2x";
+ case ScaleVecSize::X4: return ".scale_4x";
+ }
+ return "";
+ };
+ }],
+ MMA_SP_BLOCK_SCALE_INTR<>.id, [{
+ return 0;
+ }
+
+ MMATypes accumPtxType();
+ MMATypes resultPtxType();
+
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }]);
+
+ let builders = [
+ OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+ "ValueRange":$operandB, "ValueRange":$operandC,
+ "Value":$sparseMetadata, "Value":$sparsitySelector,
+ "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+ "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+ "ArrayRef<int64_t>":$shape,
+ "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+ "ScaleVecSize":$scaleVecSize,
+ "BlockScaleFormat":$blockScaleFormat,
+ "MMABlockScaleKind":$kind)>
+ ];
+
+ string llvmBuilder = [{
+ auto [id, a...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
// MMA Block Scale Operations - Shared Helpers
//===----------------------------------------------------------------------===//

namespace {
@kvederni Why should this be in an anonymous namespace?
Could we mark them `static` for internal linkage instead?
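For reference, the LLVM coding standards generally prefer `static` for file-local functions and reserve anonymous namespaces for type declarations. A minimal standalone sketch of the two spellings under discussion; the helper names here are hypothetical, not the actual helpers in this patch:

```cpp
// Both spellings give the helper internal linkage; the names are
// placeholders, not the real helpers from the patch.

// Option 1: anonymous namespace (the spelling used in the patch).
namespace {
unsigned scaleVecWidth(unsigned log2Width) { return 1u << log2Width; }
} // namespace

// Option 2: plain `static` -- same internal linkage, no extra nesting.
static unsigned scaleVecWidthStatic(unsigned log2Width) {
  return 1u << log2Width;
}

unsigned demo() { return scaleVecWidth(1) + scaleVecWidthStatic(2); }
```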
Kirill explained offline that it is a common practice in MLIR; I will let @grypp make the call on this one.
    result =
        emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
                    "attributes for mma.m16n8k64");
  } else
Can we please add braces around this for consistency?
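Something along these lines, as a standalone sketch of the requested style; the condition and the `else` body are hypothetical, since they are elided from the quoted context:

```cpp
#include <string>

// Sketch of the brace style being requested; condition and else body
// are illustrative, not the actual verifier code.
std::string checkShape(bool supported) {
  std::string result;
  if (!supported) {
    result = "unsupported Kind, ScaleVecSize and BlockScaleFormat "
             "attributes for mma.m16n8k64";
  } else { // braces added for consistency with the `if` branch
    result = "ok";
  }
  return result;
}
```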
  }));
}

MMATypes MmaBlockScaleOp::accumPtxType() {
Can we replace the function with its body at the call site?
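A hypothetical before/after sketch of that refactor; the names and bodies are illustrative, since the accessor's implementation is not shown in the quoted context:

```cpp
// Folding a single-use accessor into its call site; not the actual
// dialect code, just the shape of the suggested change.
enum class MMATypes { f16, f32 };

struct MmaLike {
  MMATypes accum;
  // Before: a one-line accessor called from exactly one place.
  MMATypes accumPtxType() { return accum; }
};

MMATypes callerBefore(MmaLike op) { return op.accumPtxType(); }
// After: the accessor's body is inlined at the call site and the
// member function can be deleted.
MMATypes callerAfter(MmaLike op) { return op.accum; }
```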