Skip to content

Conversation

@kvederni
Copy link
Contributor

@kvederni kvederni commented Dec 3, 2025

This change adds intrinsics for dense and sparse MMA with block scaling to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementation is based on PTX ISA 9.0.

@llvmbot
Copy link
Member

llvmbot commented Dec 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Kirill Vedernikov (kvederni)

Changes

This change adds intrinsics for dense and sparse MMA with block scaling to the MLIR -> NVVM IR -> NVPTX flow. The NVVM and NVPTX implementation is based on PTX ISA 9.0.


Patch is 121.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170566.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+449-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+633-1)
  • (added) mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir (+525)
  • (added) mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir (+637)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a96d65d3fcacd..1faa435fca6f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS {
             bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
             subint_mma_sp_ops, int_mma_sp_ops);
 
+  // Block scale MMA operations (dense)
+  list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+            [GEOM<16,8,64>],
+            ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+            [GEOM<16,8,32>],
+            ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+            ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+            ["f32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+            mxf4_mma_ops, mxf8f6f4_mma_ops);
+
+  // Block scale sparse MMA operations
+  list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,128>],
+            ["e2m1"], ["e2m1"], ["f32"], []>.ret;
+  list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+            [GEOM<16,8,64>],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+            mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
     The optional `orderedMetadata` attribute specifies the metadata ordering:
     - Absence (default): Uses standard sparse metadata ordering
     - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+)
-    
+
     The optional `kind` attribute specifies mixed-precision modes for FP8 operations:
     - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+)
     - Only valid with ordered metadata and m16n8k64 shape
@@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
                           sparseMetadata[%meta] selector[%sel]
                           {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}}
         : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-    
+
     // With ordered metadata:
     %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
                           sparseMetadata[%meta] selector[%sel]
@@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+def ScaleVecSize1X  : I32EnumAttrCase<"X1", 0, "x1">;
+def ScaleVecSize2X  : I32EnumAttrCase<"X2", 1, "x2">;
+def ScaleVecSize4X  : I32EnumAttrCase<"X4", 2, "x4">;
+
+def ScaleVecSize : I32EnumAttr<
+  "ScaleVecSize",
+  "MMA Scale Vector Sizes",
+  [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+
+def ScaleVecSizeAttr : EnumAttr<NVVM_Dialect, ScaleVecSize, "scale_vec_size"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">;
+def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">;
+
+def BlockScaleFormat : I32EnumAttr<
+  "BlockScaleFormat",
+  "MMA Block Scale Format",
+  [UE8M0, UE4M3]
+> {
+  let cppNamespace = "::mlir::NVVM";
+  let genSpecializedAttr = 0;
+}
+
+def BlockScaleFormatAttr : EnumAttr<NVVM_Dialect, BlockScaleFormat, "block_scale_format"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def MMABlockScaleKindMXF8F6F4  : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">;
+def MMABlockScaleKindMXF4  : I32EnumAttrCase<"MXF4", 1, "mxf4">;
+def MMABlockScaleKindMXF4NVF4  : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">;
+
+def MMABlockScaleKind : I32EnumAttr<
+  "MMABlockScaleKind",
+  "Block Scale Kind",
+  [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+
+def MMABlockScaleKindAttr : EnumAttr<NVVM_Dialect, MMABlockScaleKind, "block_scale_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+/// Generate enum value of the mma.block_scale intrinsic.
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                           WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma_block_scale"
+              # "_" # A.geom
+              # "_row_col"
+              # "_" # Kind
+              # !subst(".", "_", ScaleVecSize)
+              # signature
+              # "_" # SType;
+}
+
+/// Generate enum value of the mma.sp.block_scale intrinsic.
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                              WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale"
+              # "_" # A.geom
+              # "_row_col"
+              # "_" # Kind
+              # !subst(".", "_", ScaleVecSize)
+              # signature
+              # "_" # SType;
+}
+
+// Returns true if this combination is supported for MMA.BLOCK_SCALE ops.
+// This references the NVVM_MMA_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+                                     string stype, string scale_vec_size> {
+  string geom = frags[0].geom;
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x")),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_2x"),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_4x"),
+         !eq(stype, "ue4m3")) : true,
+    !and(!eq(geom, "m16n8k32"),
+         !eq(kind, "mxf8f6f4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x")),
+         !eq(stype, "ue8m0")) : true,
+    true: false
+  );
+}
+
+// Returns true if this combination is supported for MMA.SP.BLOCK_SCALE ops.
+// This references the NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+                                        string stype, string scale_vec_size> {
+  string geom = frags[0].geom;
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x"))): true,
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue8m0"),
+         !eq(scale_vec_size, ".scale_2x")): true,
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue4m3"),
+         !eq(scale_vec_size, ".scale_4x")): true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf8f6f4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x"))): true,
+    true: false
+  );
+}
+
+/// Helper to create the mapping between the configuration and the mma.block_scale
+/// intrinsic enum value.
+class MMA_BLOCK_SCALE_INTR {
+  list<list<list<list<string>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops,
+      !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+        !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+          !foreach(stype, ["ue8m0", "ue4m3"],
+            !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+                "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+                # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+                # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+                # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+                # " && \"" # kind # "\" == stringifyEnum(kind)"
+                # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+                # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+                # "  return " #
+                MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+                "") // if supported
+          ) // stype
+        ) // scale_vec_size
+      ) // kind
+    ); // all_mma_block_scale_ops
+  list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+                                       !listconcat(acc, el));
+  list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+  string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+/// Helper to create the mapping between the configuration and the mma.sp.block_scale
+/// intrinsic enum value.
+class MMA_SP_BLOCK_SCALE_INTR {
+  list<list<list<list<string>>>> cond0 =
+    !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops,
+      !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+        !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"],
+          !foreach(stype, ["ue8m0", "ue4m3"],
+            !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret,
+                "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k
+                # " && \"" # op[0].ptx_elt_type # "\" == eltypeA"
+                # " && \"" # op[1].ptx_elt_type # "\" == eltypeB"
+                # " && \"" # op[2].ptx_elt_type # "\" == eltypeC"
+                # " && \"" # kind # "\" == stringifyEnum(kind)"
+                # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)"
+                # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n"
+                # "  return " #
+                MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size, op[0], op[1], op[2], op[3]>.id # ";",
+                "") // if supported
+          ) // stype
+        ) // scale_vec_size
+      ) // kind
+    ); // all_mma_sp_block_scale_ops
+  list<list<list<string>>> f1 = !foldl([[[""]]], cond0, acc, el,
+                                       !listconcat(acc, el));
+  list<list<string>> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el));
+  list<string> f3 = !foldl([""], f2, acc, el, !listconcat(acc, el));
+  string id = !foldl("", f3, acc, el, acc # "\n" # el);
+}
+
+// Common base class for MMA block scale operations (dense and sparse)
+class NVVM_MmaBlockScaleBase<string mnemonic, list<Trait> traits = []> :
+    NVVM_Op<mnemonic, !listconcat([AttrSizedOperandSegments], traits)> {
+
+  let results = (outs LLVM_AnyStruct:$res);
+
+  // Common attributes shared by both dense and sparse variants
+  dag commonArguments = (ins
+           NVVM_MMAShapeAttr:$shape,
+           OptionalAttr<MMATypesAttr>:$multiplicandAPtxType,
+           OptionalAttr<MMATypesAttr>:$multiplicandBPtxType,
+           ScaleVecSizeAttr:$scaleVecSize,
+           BlockScaleFormatAttr:$blockScaleFormat,
+           MMABlockScaleKindAttr:$kind);
+
+  // Common variadic operands for A, B, C matrices
+  dag commonVariadicOperands = (ins
+           Variadic<LLVM_Type>:$operandA,
+           Variadic<LLVM_Type>:$operandB,
+           Variadic<LLVM_Type>:$operandC);
+
+  // Common scale operands for both A and B
+  dag commonScaleOperands = (ins
+             I32:$scaleAData,
+             I16:$byteIdA,
+             I16:$threadIdA,
+             I32:$scaleBData,
+             I16:$byteIdB,
+             I16:$threadIdB);
+
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, uint64_t k,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum,
+            mlir::NVVM::ScaleVecSize scaleVecSize,
+            mlir::NVVM::BlockScaleFormat blockScaleFormat,
+            mlir::NVVM::MMABlockScaleKind kind) {
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+        auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+          switch (svs) {
+            case ScaleVecSize::X1: return ".scale_1x";
+            case ScaleVecSize::X2: return ".scale_2x";
+            case ScaleVecSize::X4: return ".scale_4x";
+          }
+          return "";
+        };
+        }],
+        MMA_BLOCK_SCALE_INTR<>.id, [{
+        return 0;
+      }
+
+      // Common declarations - implementations in NVVMDialect.cpp
+      MMATypes accumPtxType();
+      MMATypes resultPtxType();
+
+      static mlir::NVVM::IDArgPair
+      getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                            llvm::IRBuilderBase& builder);
+    }]);
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> {
+
+  let summary = "cooperative matrix-multiply and accumulate with block scaling";
+
+  let description = [{
+    The `nvvm.mma.block_scale` operation collectively performs the operation
+    `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp.
+
+    A, B, C and D are dense matrices and SF_A and SF_B are scaling factors.
+    Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+    and the data type must be either ue8m0 or ue4m3.
+
+    All the threads in the warp must execute the same `mma.block_scale` operation.
+
+    This operation follows the same design pattern as `nvvm.mma.sync`, with additional
+    scaling operands for both A and B matrices.
+
+    Example:
+    ```mlir
+    %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+                              scaleA[%scaleAData, %byteIdA, %threadIdA]
+                              scaleB[%scaleBData, %byteIdB, %threadIdB]
+                              {shape = #nvvm.shape<m = 16, n = 8, k = 64>,
+                               multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+                               multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+                               scaleVecSize = #nvvm.scale_vec_size<x2>,
+                               blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+                               kind = #nvvm.block_scale_kind<mxf4nvf4>}
+        : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+    ```
+  }];
+
+  // Combine common attributes and operands
+  let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands);
+
+  let builders = [
+      OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+        "ValueRange":$operandB, "ValueRange":$operandC,
+        "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+        "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+        "ArrayRef<int64_t>":$shape,
+        "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+        "ScaleVecSize":$scaleVecSize,
+        "BlockScaleFormat":$blockScaleFormat,
+        "MMABlockScaleKind":$kind)>
+    ];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
+  }];
+}
+
+def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> {
+
+  let summary = "cooperative sparse matrix-multiply and accumulate with block scaling";
+
+  let description = [{
+    The `nvvm.mma.sp.block_scale` operation collectively performs the operation
+    `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp.
+
+    A is a sparse matrix, and B, C and D are dense matrices.
+    SF_A and SF_B are scaling factors.
+    Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4),
+    and the data type must be either ue8m0 or ue4m3.
+
+    This operation is similar to `nvvm.mma.block_scale` but with structured sparsity
+    in the A operand. The sparsity follows the 2:4 structured sparse pattern
+    where 2 out of every 4 elements are non-zero.
+
+    All the threads in the warp must execute the same `mma.sp.block_scale` operation.
+
+    The `sparseMetadata` operand provides the sparsity indices that indicate
+    which elements in the A operand are non-zero. The `sparsitySelector`
+    controls how the indices are distributed among threads in the warp and
+    should typically be 0 or 1.
+
+    This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional
+    scaling operands for both A and B matrices. Note that sparse block scale operations
+    always use ordered metadata (sm_90+).
+
+    Example:
+    ```mlir
+    %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1]
+                                 sparseMetadata[%meta] selector[%sel]
+                                 scaleA[%scaleAData, %byteIdA, %threadIdA]
+                                 scaleB[%scaleBData, %byteIdB, %threadIdB]
+                                 {shape = #nvvm.shape<m = 16, n = 8, k = 128>,
+                                  multiplicandAPtxType = #nvvm.mma_type<e2m1>,
+                                  multiplicandBPtxType = #nvvm.mma_type<e2m1>,
+                                  scaleVecSize = #nvvm.scale_vec_size<x2>,
+                                  blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+                                  kind = #nvvm.block_scale_kind<mxf4>}
+        : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)>
+    ```
+  }];
+
+  // Sparse-specific attributes and operands
+  dag sparseSpecificArguments = (ins
+           UnitAttr:$orderedMetadata);
+
+  dag sparseSpecificOperands = (ins
+             I32:$sparseMetadata,
+             I32:$sparsitySelector);
+
+  // Combine common and sparse-specific attributes and operands
+  let arguments = !con(commonArguments, sparseSpecificArguments, 
+                       commonVariadicOperands, sparseSpecificOperands,
+                       commonScaleOperands);
+
+  // Override extraClassDeclaration to use sparse intrinsics
+  let extraClassDeclaration = !strconcat([{
+      static llvm::Intrinsic::ID getIntrinsicID(
+            int64_t m, int64_t n, uint64_t k,
+            mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum,
+            mlir::NVVM::MMATypes eltypeCEnum,
+            mlir::NVVM::ScaleVecSize scaleVecSize,
+            mlir::NVVM::BlockScaleFormat blockScaleFormat,
+            mlir::NVVM::MMABlockScaleKind kind) {
+        llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum);
+        llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum);
+        llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum);
+
+        auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string {
+          switch (svs) {
+            case ScaleVecSize::X1: return ".scale_1x";
+            case ScaleVecSize::X2: return ".scale_2x";
+            case ScaleVecSize::X4: return ".scale_4x";
+          }
+          return "";
+        };
+        }],
+        MMA_SP_BLOCK_SCALE_INTR<>.id, [{
+        return 0;
+      }
+
+      MMATypes accumPtxType();
+      MMATypes resultPtxType();
+
+      static mlir::NVVM::IDArgPair
+      getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                            llvm::IRBuilderBase& builder);
+    }]);
+
+  let builders = [
+      OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA,
+        "ValueRange":$operandB, "ValueRange":$operandC,
+        "Value":$sparseMetadata, "Value":$sparsitySelector,
+        "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA,
+        "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB,
+        "ArrayRef<int64_t>":$shape,
+        "std::optional<std::array<MMATypes, 2>>":$multiplicandPtxTypes,
+        "ScaleVecSize":$scaleVecSize,
+        "BlockScaleFormat":$blockScaleFormat,
+        "MMABlockScaleKind":$kind)>
+    ];
+
+  string llvmBuilder = [{
+    auto [id, a...
[truncated]

@github-actions
Copy link

github-actions bot commented Dec 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@durga4github durga4github changed the title [MLIR] Support for dense and sparse MMA with block scaling [MLIR][NVVM] Support for dense and sparse MMA with block scaling Dec 4, 2025
// MMA Block Scale Operations - Shared Helpers
//===----------------------------------------------------------------------===//

namespace {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kvederni Why should this be in an anonymous namespace?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we instead mark them `static` to give them internal linkage?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kirill explained offline that this is common practice in MLIR; I will let @grypp make the call on this one.

result =
emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat "
"attributes for mma.m16n8k64");
} else
Copy link
Contributor

@schwarzschild-radius schwarzschild-radius Dec 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please add braces around this for consistency?

}));
}

MMATypes MmaBlockScaleOp::accumPtxType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we remove this function and inline its body at the call site?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants