From 277da943277f20f3cc2cfcfbf1a5af07645722f2 Mon Sep 17 00:00:00 2001 From: Kirill Vedernikov Date: Wed, 3 Dec 2025 22:55:24 +0100 Subject: [PATCH 1/3] [MLIR] Support for dense and sparse MMA with block scaling --- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 451 ++++++++++++- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 634 ++++++++++++++++- .../Dialect/LLVMIR/nvvm-mma-blockscale.mlir | 525 +++++++++++++++ .../LLVMIR/nvvm-mma-sparse-blockscale.mlir | 637 ++++++++++++++++++ 4 files changed, 2244 insertions(+), 3 deletions(-) create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir create mode 100644 mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index a96d65d3fcacd..1faa435fca6f9 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -2499,6 +2499,30 @@ class NVVM_MMA_OPS { bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops, subint_mma_sp_ops, int_mma_sp_ops); + // Block scale MMA operations (dense) + list> mxf4_mma_ops = MMA_OPS< + [GEOM<16,8,64>], + ["e2m1"], ["e2m1"], ["f32"], []>.ret; + list> mxf8f6f4_mma_ops = MMA_OPS< + [GEOM<16,8,32>], + ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], + ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], + ["f32"], []>.ret; + list> all_mma_block_scale_ops = !listconcat( + mxf4_mma_ops, mxf8f6f4_mma_ops); + + // Block scale sparse MMA operations + list> mxf4xx_mma_sp_ops = MMA_OPS< + [GEOM<16,8,128>], + ["e2m1"], ["e2m1"], ["f32"], []>.ret; + list> mxf8f6f4_mma_sp_ops = MMA_OPS< + [GEOM<16,8,64>], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], + ["f32"], []>.ret; + list> all_mma_sp_block_scale_ops = !listconcat( + mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops); + } def NVVM_MMA_OPS : NVVM_MMA_OPS; @@ -3332,7 +3356,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { The optional `orderedMetadata` attribute specifies the metadata ordering: - Absence (default): Uses standard sparse metadata ordering - Presence: Uses ordered metadata (PTX ISA 8.5+, sm_90+) - + The optional `kind` attribute specifies mixed-precision modes for FP8 operations: - `f8f6f4`: Enables e3m2, e2m3, e2m1 FP8 types and f16 accumulator (PTX ISA 8.7+, sm_90+) - Only valid with ordered metadata and m16n8k64 shape @@ -3347,7 +3371,7 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { sparseMetadata[%meta] selector[%sel] {shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> - + // With ordered metadata: %d = nvvm.mma.sp.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] sparseMetadata[%meta] selector[%sel] @@ -3416,6 +3440,429 @@ def NVVM_MmaSpOp : NVVM_Op<"mma.sp.sync", [AttrSizedOperandSegments]> { let hasVerifier = 1; } +def ScaleVecSize1X : I32EnumAttrCase<"X1", 0, "x1">; +def ScaleVecSize2X : I32EnumAttrCase<"X2", 1, "x2">; +def ScaleVecSize4X : I32EnumAttrCase<"X4", 2, "x4">; + +def ScaleVecSize : I32EnumAttr< + "ScaleVecSize", + "MMA Scale Vector Sizes", + [ScaleVecSize1X, ScaleVecSize2X, ScaleVecSize4X]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def ScaleVecSizeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def UE8M0 : I32EnumAttrCase<"UE8M0", 0, "ue8m0">; +def UE4M3 : I32EnumAttrCase<"UE4M3", 1, "ue4m3">; + +def BlockScaleFormat : I32EnumAttr< + "BlockScaleFormat", + "MMA Block Scale 
Format", + [UE8M0, UE4M3] +> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def BlockScaleFormatAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def MMABlockScaleKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 0, "mxf8f6f4">; +def MMABlockScaleKindMXF4 : I32EnumAttrCase<"MXF4", 1, "mxf4">; +def MMABlockScaleKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 2, "mxf4nvf4">; + +def MMABlockScaleKind : I32EnumAttr< + "MMABlockScaleKind", + "Block Scale Kind", + [MMABlockScaleKindMXF8F6F4, MMABlockScaleKindMXF4, MMABlockScaleKindMXF4NVF4]> { + let cppNamespace = "::mlir::NVVM"; + let genSpecializedAttr = 0; +} + +def MMABlockScaleKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +/// Generate enum value of the mma.block_scale intrinsic. +class MMA_BLOCK_SCALE_NAME { + string signature = MMA_SIGNATURE.ret; + string id = "llvm::Intrinsic::nvvm_mma_block_scale" + # "_" # A.geom + # "_row_col" + # "_" # Kind + # !subst(".", "_", ScaleVecSize) + # signature + # "_" # SType; +} + +/// Generate enum value of the mma.sp.block_scale intrinsic. +class MMA_SP_BLOCK_SCALE_NAME { + string signature = MMA_SIGNATURE.ret; + string id = "llvm::Intrinsic::nvvm_mma_sp_ordered_metadata_block_scale" + # "_" # A.geom + # "_row_col" + # "_" # Kind + # !subst(".", "_", ScaleVecSize) + # signature + # "_" # SType; +} + +// Returns true if this combination is supported for MMA.BLOCK_SCALE ops. +// This references the NVVM_MMA_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td +class NVVM_MMA_BLOCK_SCALE_SUPPORTED frags, string kind, + string stype, string scale_vec_size> { + string geom = frags[0].geom; + bit ret = !cond( + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_2x")), + !eq(stype, "ue8m0")) : true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4nvf4"), + !eq(scale_vec_size, ".scale_2x"), + !eq(stype, "ue8m0")) : true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf4nvf4"), + !eq(scale_vec_size, ".scale_4x"), + !eq(stype, "ue4m3")) : true, + !and(!eq(geom, "m16n8k32"), + !eq(kind, "mxf8f6f4"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_1x")), + !eq(stype, "ue8m0")) : true, + true: false + ); +} + +// Returns true if this combination is supported for MMA.SP.BLOCK_SCALE ops. +// This references the NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED class from IntrinsicsNVVM.td +class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED frags, string kind, + string stype, string scale_vec_size> { + string geom = frags[0].geom; + bit ret = !cond( + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4"), + !eq(stype, "ue8m0"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_2x"))): true, + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4nvf4"), + !eq(stype, "ue8m0"), + !eq(scale_vec_size, ".scale_2x")): true, + !and(!eq(geom, "m16n8k128"), + !eq(kind, "mxf4nvf4"), + !eq(stype, "ue4m3"), + !eq(scale_vec_size, ".scale_4x")): true, + !and(!eq(geom, "m16n8k64"), + !eq(kind, "mxf8f6f4"), + !eq(stype, "ue8m0"), + !or(!eq(scale_vec_size, ""), + !eq(scale_vec_size, ".scale_1x"))): true, + true: false + ); +} + +/// Helper to create the mapping between the configuration and the mma.block_scale +/// intrinsic enum value. 
+class MMA_BLOCK_SCALE_INTR { + list>>> cond0 = + !foreach(op, NVVM_MMA_OPS.all_mma_block_scale_ops, + !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"], + !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"], + !foreach(stype, ["ue8m0", "ue4m3"], + !if(NVVM_MMA_BLOCK_SCALE_SUPPORTED.ret, + "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k + # " && \"" # op[0].ptx_elt_type # "\" == eltypeA" + # " && \"" # op[1].ptx_elt_type # "\" == eltypeB" + # " && \"" # op[2].ptx_elt_type # "\" == eltypeC" + # " && \"" # kind # "\" == stringifyEnum(kind)" + # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)" + # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n" + # " return " # + MMA_BLOCK_SCALE_NAME.id # ";", + "") // if supported + ) // stype + ) // scale_vec_size + ) // kind + ); // all_mma_block_scale_ops + list>> f1 = !foldl([[[""]]], cond0, acc, el, + !listconcat(acc, el)); + list> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el)); + list f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); + string id = !foldl("", f3, acc, el, acc # "\n" # el); +} + +/// Helper to create the mapping between the configuration and the mma.sp.block_scale +/// intrinsic enum value. +class MMA_SP_BLOCK_SCALE_INTR { + list>>> cond0 = + !foreach(op, NVVM_MMA_OPS.all_mma_sp_block_scale_ops, + !foreach(kind, ["mxf4", "mxf4nvf4", "mxf8f6f4"], + !foreach(scale_vec_size, ["", ".scale_1x", ".scale_2x", ".scale_4x"], + !foreach(stype, ["ue8m0", "ue4m3"], + !if(NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED.ret, + "if (m == " # op[0].m # " && n == " # op[0].n # " && k == " # op[0].k + # " && \"" # op[0].ptx_elt_type # "\" == eltypeA" + # " && \"" # op[1].ptx_elt_type # "\" == eltypeB" + # " && \"" # op[2].ptx_elt_type # "\" == eltypeC" + # " && \"" # kind # "\" == stringifyEnum(kind)" + # " && \"" # stype # "\" == stringifyEnum(blockScaleFormat)" + # " && \"" # scale_vec_size # "\" == getScaleVecSizeStr(scaleVecSize))\n" + # " return " # + MMA_SP_BLOCK_SCALE_NAME.id # ";", + "") // if supported + ) // stype + ) // scale_vec_size + ) // kind + ); // all_mma_sp_block_scale_ops + list>> f1 = !foldl([[[""]]], cond0, acc, el, + !listconcat(acc, el)); + list> f2 = !foldl([[""]], f1, acc, el, !listconcat(acc, el)); + list f3 = !foldl([""], f2, acc, el, !listconcat(acc, el)); + string id = !foldl("", f3, acc, el, acc # "\n" # el); +} + +// Common base class for MMA block scale operations (dense and sparse) +class NVVM_MmaBlockScaleBase traits = []> : + NVVM_Op { + + let results = (outs LLVM_AnyStruct:$res); + + // Common attributes shared by both dense and sparse variants + dag commonArguments = (ins + NVVM_MMAShapeAttr:$shape, + OptionalAttr:$multiplicandAPtxType, + OptionalAttr:$multiplicandBPtxType, + ScaleVecSizeAttr:$scaleVecSize, + BlockScaleFormatAttr:$blockScaleFormat, + MMABlockScaleKindAttr:$kind); + + // Common variadic operands for A, B, C matrices + dag commonVariadicOperands = (ins + Variadic:$operandA, + Variadic:$operandB, + Variadic:$operandC); + + // Common scale operands for both A and B + dag commonScaleOperands = (ins + I32:$scaleAData, + I16:$byteIdA, + I16:$threadIdA, + I32:$scaleBData, + I16:$byteIdB, + I16:$threadIdB); + + let extraClassDeclaration = !strconcat([{ + static llvm::Intrinsic::ID getIntrinsicID( + int64_t m, int64_t n, uint64_t k, + mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, + mlir::NVVM::MMATypes eltypeCEnum, + mlir::NVVM::ScaleVecSize scaleVecSize, + mlir::NVVM::BlockScaleFormat blockScaleFormat, + 
mlir::NVVM::MMABlockScaleKind kind) { + llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); + llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); + llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); + + auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string { + switch (svs) { + case ScaleVecSize::X1: return ".scale_1x"; + case ScaleVecSize::X2: return ".scale_2x"; + case ScaleVecSize::X4: return ".scale_4x"; + } + return ""; + }; + }], + MMA_BLOCK_SCALE_INTR<>.id, [{ + return 0; + } + + // Common declarations - implementations in NVVMDialect.cpp + MMATypes accumPtxType(); + MMATypes resultPtxType(); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]); + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def NVVM_MmaBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.block_scale"> { + + let summary = "cooperative matrix-multiply and accumulate with block scaling"; + + let description = [{ + The `nvvm.mma.block_scale` operation collectively performs the operation + `D = matmul(A * SF_A, B * SF_B) + C` using all threads in a warp. + + A, B, C and D are dense matrices and SF_A and SF_B are scaling factors. + Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4), + and the data type must be either ue8m0 or ue4m3. + + All the threads in the warp must execute the same `mma.block_scale` operation. + + This operation follows the same design pattern as `nvvm.mma.sync`, with additional + scaling operands for both A and B matrices. + + Example: + ```mlir + %d = nvvm.mma.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)> + ``` + }]; + + // Combine common attributes and operands + let arguments = !con(commonArguments, commonVariadicOperands, commonScaleOperands); + + let builders = [ + OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, + "ValueRange":$operandB, "ValueRange":$operandC, + "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA, + "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB, + "ArrayRef":$shape, + "std::optional>":$multiplicandPtxTypes, + "ScaleVecSize":$scaleVecSize, + "BlockScaleFormat":$blockScaleFormat, + "MMABlockScaleKind":$kind)> + ]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MmaBlockScaleOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + +def NVVM_MmaSpBlockScaleOp : NVVM_MmaBlockScaleBase<"mma.sp.block_scale"> { + + let summary = "cooperative sparse matrix-multiply and accumulate with block scaling"; + + let description = [{ + The `nvvm.mma.sp.block_scale` operation collectively performs the operation + `D = matmul(A_sparse * SF_A, B * SF_B) + C` using all threads in a warp. + + A is a sparse matrix, and B, C and D are dense matrices. + SF_A and SF_B are scaling factors. + Dimensions of SF_A and SF_B are based on scale vector sizes (x1, x2, x4), + and the data type must be either ue8m0 or ue4m3. + + This operation is similar to `nvvm.mma.block_scale` but with structured sparsity + in the A operand. 
The sparsity follows the 2:4 structured sparse pattern + where 2 out of every 4 elements are non-zero. + + All the threads in the warp must execute the same `mma.sp.block_scale` operation. + + The `sparseMetadata` operand provides the sparsity indices that indicate + which elements in the A operand are non-zero. The `sparsitySelector` + controls how the indices are distributed among threads in the warp and + should typically be 0 or 1. + + This operation follows the same design pattern as `nvvm.mma.sp.sync`, with additional + scaling operands for both A and B matrices. Note that sparse block scale operations + always use ordered metadata (sm_90+). + + Example: + ```mlir + %d = nvvm.mma.sp.block_scale A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] + sparseMetadata[%meta] selector[%sel] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<2xf16>, vector<2xf16>, vector<2xf32>) -> !llvm.struct<(f32, f32)> + ``` + }]; + + // Sparse-specific attributes and operands + dag sparseSpecificArguments = (ins + UnitAttr:$orderedMetadata); + + dag sparseSpecificOperands = (ins + I32:$sparseMetadata, + I32:$sparsitySelector); + + // Combine common and sparse-specific attributes and operands + let arguments = !con(commonArguments, sparseSpecificArguments, + commonVariadicOperands, sparseSpecificOperands, + commonScaleOperands); + + // Override extraClassDeclaration to use sparse intrinsics + let extraClassDeclaration = !strconcat([{ + static llvm::Intrinsic::ID getIntrinsicID( + int64_t m, int64_t n, uint64_t k, + mlir::NVVM::MMATypes eltypeAEnum, mlir::NVVM::MMATypes eltypeBEnum, + mlir::NVVM::MMATypes eltypeCEnum, + mlir::NVVM::ScaleVecSize scaleVecSize, + mlir::NVVM::BlockScaleFormat blockScaleFormat, + mlir::NVVM::MMABlockScaleKind kind) { + llvm::StringRef eltypeA = stringifyEnum(eltypeAEnum); + llvm::StringRef eltypeB = stringifyEnum(eltypeBEnum); + llvm::StringRef eltypeC = stringifyEnum(eltypeCEnum); + + auto getScaleVecSizeStr = [](ScaleVecSize svs) -> std::string { + switch (svs) { + case ScaleVecSize::X1: return ".scale_1x"; + case ScaleVecSize::X2: return ".scale_2x"; + case ScaleVecSize::X4: return ".scale_4x"; + } + return ""; + }; + }], + MMA_SP_BLOCK_SCALE_INTR<>.id, [{ + return 0; + } + + MMATypes accumPtxType(); + MMATypes resultPtxType(); + + static mlir::NVVM::IDArgPair + getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase& builder); + }]); + + let builders = [ + OpBuilder<(ins "Type":$resultType, "ValueRange":$operandA, + "ValueRange":$operandB, "ValueRange":$operandC, + "Value":$sparseMetadata, "Value":$sparsitySelector, + "Value":$scaleAData, "Value":$byteIdA, "Value":$threadIdA, + "Value":$scaleBData, "Value":$byteIdB, "Value":$threadIdB, + "ArrayRef":$shape, + "std::optional>":$multiplicandPtxTypes, + "ScaleVecSize":$scaleVecSize, + "BlockScaleFormat":$blockScaleFormat, + "MMABlockScaleKind":$kind)> + ]; + + string llvmBuilder = [{ + auto [id, args] = NVVM::MmaSpBlockScaleOp::getIntrinsicIDAndArgs( + *op, moduleTranslation, builder); + $res = createIntrinsicCall(builder, id, args); + }]; +} + //===----------------------------------------------------------------------===// // NVVM TMA Ops //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index ada4223ac12de..a06fe19b8ad1c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -694,7 +694,7 @@ void MmaOp::print(OpAsmPrinter &p) { } } std::optional inferredType = - inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2); if (inferredType) ignoreAttrNames.push_back(frag.ptxTypeAttr); } @@ -1559,6 +1559,638 @@ LogicalResult MmaSpOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// MMA Block Scale Operations - Shared Helpers +//===----------------------------------------------------------------------===// + +namespace { +// Shared structure for MMA operand fragments (A, B, C) +struct OperandFragment { + StringRef operandName; + StringRef ptxTypeAttr; + SmallVector regs; + explicit OperandFragment(StringRef name, StringRef ptxTypeName) + : operandName(name), ptxTypeAttr(ptxTypeName) {} +}; + +// Helper to print operand list in the format: name[operands] +void printOperandList(OpAsmPrinter &p, StringRef name, + ArrayRef operands) { + p << " " << name << "["; + p.printOperands(operands); + p << "]"; +} + +// Helper to parse operand list in the format: name[operands] +LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, + SmallVectorImpl ®s) { + if (parser.parseKeyword(operandName).failed()) + return failure(); + if (parser.parseOperandList(regs, + OpAsmParser::Delimiter::OptionalSquare).failed()) + return failure(); + return success(); +} + +// Helper to process operand fragments and determine which attributes can be inferred +template +void processOperandFragments(Op &op, std::array &frags, + SmallVectorImpl ®Types, + SmallVectorImpl &ignoreAttrNames) { + for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) { + auto &frag = frags[fragIdx]; + auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx); + for (auto operandIdx = varOperandSpec.first; + operandIdx < varOperandSpec.first + varOperandSpec.second; + operandIdx++) { + frag.regs.push_back(op.getOperand(operandIdx)); + if (fragIdx == 0 && operandIdx == varOperandSpec.first) { + regTypes.push_back(op.getOperand(operandIdx).getType()); + } + } + if (fragIdx < 2) { + regTypes.push_back(frag.regs[0].getType()); + } + std::optional inferredType = + MmaOp::inferOperandMMAType(regTypes.back(), + /*isAccumulator=*/fragIdx >= 2); + if (inferredType) + ignoreAttrNames.push_back(frag.ptxTypeAttr); + } +} + +// Helper to parse type signature: (A_type, B_type, C_type) +LogicalResult parseMmaTypeSignature(OpAsmParser &parser, + SmallVectorImpl &operandTypes) { + if (parser.parseColon().failed() || parser.parseLParen().failed()) + return failure(); + + for (int i = 0; i < 3; i++) { + if (i > 0 && parser.parseComma().failed()) + return failure(); + Type ty; + if (parser.parseType(ty).failed()) + return failure(); + operandTypes.push_back(ty); + } + + return parser.parseRParen(); +} + +// Helper to infer and set multiplicand PTX type attributes +void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, + const SmallVectorImpl &operandTypes) { + if (!attrs.get("multiplicandAPtxType")) { + if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[0], + false)) { + attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType)); + } + } + if (!attrs.get("multiplicandBPtxType")) { + 
if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[1], + false)) { + attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType)); + } + } +} + +// Helper to add common block scale attributes +void addBlockScaleAttributes(OpBuilder &builder, OperationState &result, + ArrayRef shape, + ScaleVecSize scaleVecSize, + BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { + MLIRContext *ctx = builder.getContext(); + result.addAttribute("shape", + builder.getAttr(shape[0], shape[1], + shape[2])); + result.addAttribute("scaleVecSize", + ScaleVecSizeAttr::get(ctx, scaleVecSize)); + result.addAttribute("blockScaleFormat", + BlockScaleFormatAttr::get(ctx, blockScaleFormat)); + result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind)); +} + +// Helper to infer and add multiplicand PTX types to builder +void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, + ValueRange operandA, ValueRange operandB, + std::optional> multiplicandPtxTypes) { + if (multiplicandPtxTypes) { + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + } else { + if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) + result.addAttribute("multiplicandAPtxType", + MMATypesAttr::get(ctx, *res)); + if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) + result.addAttribute("multiplicandBPtxType", + MMATypesAttr::get(ctx, *res)); + } +} + +// Template helper for common accumPtxType/resultPtxType implementation +template +MMATypes inferPtxTypeFromResult(OpTy op) { + return *MmaOp::inferOperandMMAType( + cast(op.getRes().getType()).getBody()[0], + /*isAccumulator=*/true); +} +} // namespace + +//===----------------------------------------------------------------------===// +// MmaBlockScaleOp +//===----------------------------------------------------------------------===// + +void MmaBlockScaleOp::print(OpAsmPrinter &p) { + SmallVector regTypes; + std::array frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", "")}; + SmallVector ignoreAttrNames{ + mlir::NVVM::MmaBlockScaleOp::getOperandSegmentSizeAttr()}; + + processOperandFragments(*this, frags, regTypes, ignoreAttrNames); + + // Print A, B, C operands + for (const auto &frag : frags) + printOperandList(p, frag.operandName, frag.regs); + + // Print scale operands + printOperandList(p, "scaleA", + {getScaleAData(), getByteIdA(), getThreadIdA()}); + printOperandList(p, "scaleB", + {getScaleBData(), getByteIdB(), getThreadIdB()}); + + p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); + + // Print type signature + p << " : ("; + llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), + frags[1].regs[0].getType(), + frags[2].regs[0].getType()}, p); + p << ")"; + p.printArrowTypeList(TypeRange{this->getRes().getType()}); +} + +ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser, + OperationState &result) { + struct LocalOperandFragment { + std::optional elemtype; + SmallVector regs; + }; + + Builder &builder = parser.getBuilder(); + std::array frags; + NamedAttrList namedAttributes; + + // Parse A[...] B[...] C[...] 
+ if (parseMmaOperand(parser, "A", frags[0].regs).failed() || + parseMmaOperand(parser, "B", frags[1].regs).failed() || + parseMmaOperand(parser, "C", frags[2].regs).failed()) + return failure(); + + // Parse scale operands: scaleA[...] scaleB[...] + SmallVector scaleAOperands, scaleBOperands; + if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() || + parseMmaOperand(parser, "scaleB", scaleBOperands).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse type signature + SmallVector operandTypes; + if (parseMmaTypeSignature(parser, operandTypes).failed()) + return failure(); + + // Parse result type + SmallVector resultTypes; + if (parser.parseArrowTypeList(resultTypes).failed()) + return failure(); + + // Infer element types and resolve operands + for (const auto &[idx, frag] : llvm::enumerate(frags)) { + frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx], + /*isAccumulator=*/idx >= 2); + if (parser.resolveOperands(frag.regs, operandTypes[idx], + parser.getNameLoc(), result.operands).failed()) + return failure(); + } + + // Resolve scale operands + SmallVector scaleTypes = { + builder.getI32Type(), builder.getI16Type(), builder.getI16Type() + }; + if (parser.resolveOperands(scaleAOperands, scaleTypes, + parser.getNameLoc(), result.operands).failed() || + parser.resolveOperands(scaleBOperands, scaleTypes, + parser.getNameLoc(), result.operands).failed()) + return failure(); + + // Add attributes + result.addAttributes(namedAttributes); + inferAndSetMultiplicandTypes(parser.getContext(), + result.attributes, operandTypes); + + result.addTypes(resultTypes); + result.addAttribute( + MmaBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); + return success(); +} + +void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result, + Type resultType, ValueRange operandA, + ValueRange operandB, ValueRange operandC, + Value scaleAData, Value byteIdA, Value threadIdA, + Value scaleBData, Value byteIdB, Value threadIdB, + ArrayRef shape, + std::optional> multiplicandPtxTypes, + ScaleVecSize scaleVecSize, + BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + + addBlockScaleAttributes(builder, result, shape, scaleVecSize, + blockScaleFormat, kind); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands({scaleAData, byteIdA, threadIdA, + scaleBData, byteIdB, threadIdB}); + + addInferredMultiplicandTypes(builder.getContext(), result, + operandA, operandB, multiplicandPtxTypes); + + result.addTypes(resultType); + result.addAttribute( + MmaBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); +} + +MMATypes MmaBlockScaleOp::accumPtxType() { + return inferPtxTypeFromResult(*this); +} + +MMATypes MmaBlockScaleOp::resultPtxType() { + return inferPtxTypeFromResult(*this); +} + +NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase 
&builder) { + auto curOp = cast(op); + + SmallVector args; + // Add A, B, C operands + for (Value operand : curOp.getOperandA()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandB()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandC()) + args.push_back(mt.lookupValue(operand)); + + // Add scale operands + args.push_back(mt.lookupValue(curOp.getScaleAData())); + args.push_back(mt.lookupValue(curOp.getByteIdA())); + args.push_back(mt.lookupValue(curOp.getThreadIdA())); + args.push_back(mt.lookupValue(curOp.getScaleBData())); + args.push_back(mt.lookupValue(curOp.getByteIdB())); + args.push_back(mt.lookupValue(curOp.getThreadIdB())); + + unsigned intId = MmaBlockScaleOp::getIntrinsicID( + curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(), + *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(), + curOp.accumPtxType(), + curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind()); + + return {intId, args}; +} + +LogicalResult MmaBlockScaleOp::verify() { + LogicalResult result = success(); + int m = getShape().getM(); + int n = getShape().getN(); + int k = getShape().getK(); + + if (m == 16 && n == 8 && k == 64) { + if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 || + getMultiplicandBPtxType() != NVVM::MMATypes::e2m1) + result = emitOpError( + "unsupported MMATypes attribute for mma.m16n8k64.(mxf4nvf4|mxf4)"); + if (getKind() == NVVM::MMABlockScaleKind::MXF4) { + if (getScaleVecSize() != NVVM::ScaleVecSize::X2) + result = emitOpError( + "unsupported ScaleVecSize attribute for mma.m16n8k64.mxf4"); + if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0) + result = emitOpError( + "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4"); + } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) { + if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || + (getScaleVecSize() == NVVM::ScaleVecSize::X4 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) + result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k64.mxf4nvf4"); + } else + result = emitOpError("unsupported Kind attribute for mma.m16n8k64"); + } else if (m == 16 && n == 8 && k == 32) { + if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 && + getScaleVecSize() == NVVM::ScaleVecSize::X1 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) + result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k32"); + } else + result = emitOpError("unsupported Geom for mma with block scaling"); + return result; +} + +//===----------------------------------------------------------------------===// +// MmaSpBlockScaleOp +//===----------------------------------------------------------------------===// + +void MmaSpBlockScaleOp::print(OpAsmPrinter &p) { + SmallVector regTypes; + std::array frags{ + OperandFragment("A", getMultiplicandAPtxTypeAttrName()), + OperandFragment("B", getMultiplicandBPtxTypeAttrName()), + OperandFragment("C", "")}; + SmallVector ignoreAttrNames{ + mlir::NVVM::MmaSpBlockScaleOp::getOperandSegmentSizeAttr()}; + + processOperandFragments(*this, frags, regTypes, ignoreAttrNames); + + // Print A, B, C operands + for (const auto &frag : frags) + printOperandList(p, frag.operandName, frag.regs); + + // Print sparse-specific operands + printOperandList(p, "sparseMetadata", {getSparseMetadata()}); + printOperandList(p, "selector", 
{getSparsitySelector()}); + + // Print scale operands + printOperandList(p, "scaleA", + {getScaleAData(), getByteIdA(), getThreadIdA()}); + printOperandList(p, "scaleB", + {getScaleBData(), getByteIdB(), getThreadIdB()}); + + p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames); + + // Print type signature + p << " : ("; + llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), + frags[1].regs[0].getType(), + frags[2].regs[0].getType()}, p); + p << ")"; + p.printArrowTypeList(TypeRange{this->getRes().getType()}); +} + +ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser, + OperationState &result) { + struct LocalOperandFragment { + std::optional elemtype; + SmallVector regs; + }; + + Builder &builder = parser.getBuilder(); + std::array frags; + NamedAttrList namedAttributes; + + // Parse A[...] B[...] C[...] + if (parseMmaOperand(parser, "A", frags[0].regs).failed() || + parseMmaOperand(parser, "B", frags[1].regs).failed() || + parseMmaOperand(parser, "C", frags[2].regs).failed()) + return failure(); + + // Parse sparse-specific operands + SmallVector metadataOperands, selectorOperands; + if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() || + parseMmaOperand(parser, "selector", selectorOperands).failed()) + return failure(); + + // Parse scale operands + SmallVector scaleAOperands, scaleBOperands; + if (parseMmaOperand(parser, "scaleA", scaleAOperands).failed() || + parseMmaOperand(parser, "scaleB", scaleBOperands).failed()) + return failure(); + + if (parser.parseOptionalAttrDict(namedAttributes).failed()) + return failure(); + + // Parse type signature + SmallVector operandTypes; + if (parseMmaTypeSignature(parser, operandTypes).failed()) + return failure(); + + // Parse result type + SmallVector resultTypes; + if (parser.parseArrowTypeList(resultTypes).failed()) + return failure(); + + // Infer element types and resolve operands + for (const auto &[idx, frag] : llvm::enumerate(frags)) { + frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx], + /*isAccumulator=*/idx >= 2); + if (parser.resolveOperands(frag.regs, operandTypes[idx], + parser.getNameLoc(), result.operands).failed()) + return failure(); + } + + // Resolve sparse metadata and selector + Type i32Type = builder.getI32Type(); + if (parser.resolveOperands(metadataOperands, i32Type, + parser.getNameLoc(), result.operands).failed() || + parser.resolveOperands(selectorOperands, i32Type, + parser.getNameLoc(), result.operands).failed()) + return failure(); + + // Resolve scale operands + SmallVector scaleTypes = { + i32Type, builder.getI16Type(), builder.getI16Type() + }; + if (parser.resolveOperands(scaleAOperands, scaleTypes, + parser.getNameLoc(), result.operands).failed() || + parser.resolveOperands(scaleBOperands, scaleTypes, + parser.getNameLoc(), result.operands).failed()) + return failure(); + + // Add attributes + result.addAttributes(namedAttributes); + inferAndSetMultiplicandTypes(parser.getContext(), + result.attributes, operandTypes); + + // orderedMetadata is mandatory + if (!result.attributes.get("orderedMetadata")) + result.addAttribute("orderedMetadata", builder.getUnitAttr()); + + result.addTypes(resultTypes); + result.addAttribute( + MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // sparseMetadata + 1, // sparsitySelector + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB 
+ 1 // threadIdB + })); + return success(); +} + +void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result, + Type resultType, ValueRange operandA, + ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, + Value scaleAData, Value byteIdA, Value threadIdA, + Value scaleBData, Value byteIdB, Value threadIdB, + ArrayRef shape, + std::optional> multiplicandPtxTypes, + ScaleVecSize scaleVecSize, + BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { + assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); + + addBlockScaleAttributes(builder, result, shape, scaleVecSize, + blockScaleFormat, kind); + result.addAttribute("orderedMetadata", builder.getUnitAttr()); + + result.addOperands(operandA); + result.addOperands(operandB); + result.addOperands(operandC); + result.addOperands({sparseMetadata, sparsitySelector, + scaleAData, byteIdA, threadIdA, + scaleBData, byteIdB, threadIdB}); + + addInferredMultiplicandTypes(builder.getContext(), result, + operandA, operandB, multiplicandPtxTypes); + + result.addTypes(resultType); + result.addAttribute( + MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), + 1, // sparseMetadata + 1, // sparsitySelector + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); +} + +MMATypes MmaSpBlockScaleOp::accumPtxType() { + return inferPtxTypeFromResult(*this); +} + +MMATypes MmaSpBlockScaleOp::resultPtxType() { + return inferPtxTypeFromResult(*this); +} + +NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto curOp = cast(op); + + SmallVector args; + // Add A, B, C operands + for (Value operand : curOp.getOperandA()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandB()) + args.push_back(mt.lookupValue(operand)); + for (Value operand : curOp.getOperandC()) + args.push_back(mt.lookupValue(operand)); + + // Add sparse metadata and selector + args.push_back(mt.lookupValue(curOp.getSparseMetadata())); + args.push_back(mt.lookupValue(curOp.getSparsitySelector())); + + // Add scale operands + args.push_back(mt.lookupValue(curOp.getScaleAData())); + args.push_back(mt.lookupValue(curOp.getByteIdA())); + args.push_back(mt.lookupValue(curOp.getThreadIdA())); + args.push_back(mt.lookupValue(curOp.getScaleBData())); + args.push_back(mt.lookupValue(curOp.getByteIdB())); + args.push_back(mt.lookupValue(curOp.getThreadIdB())); + + unsigned intId = MmaSpBlockScaleOp::getIntrinsicID( + curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(), + *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(), + curOp.accumPtxType(), + curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind()); + + return {intId, args}; +} + +LogicalResult MmaSpBlockScaleOp::verify() { + // Check that orderedMetadata is present + if (!getOrderedMetadata()) { + return emitOpError("'orderedMetadata' attribute is mandatory"); + } + + LogicalResult result = success(); + int m = getShape().getM(); + int n = getShape().getN(); + int k = getShape().getK(); + + if (m == 16 && n == 8 && k == 128) { + if (getMultiplicandAPtxType() != NVVM::MMATypes::e2m1 || + getMultiplicandBPtxType() != NVVM::MMATypes::e2m1) + result = emitOpError( + "unsupported MMATypes attribute for mma.m16n8k128.(mxf4nvf4|mxf4)"); + if 
(getKind() == NVVM::MMABlockScaleKind::MXF4) { + if (getScaleVecSize() != NVVM::ScaleVecSize::X2) + result = emitOpError( + "unsupported ScaleVecSize attribute for mma.m16n8k128.mxf4"); + if (getBlockScaleFormat() != NVVM::BlockScaleFormat::UE8M0) + result = emitOpError( + "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4"); + } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) { + if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || + (getScaleVecSize() == NVVM::ScaleVecSize::X4 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) + result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k128.mxf4nvf4"); + } else + result = emitOpError("unsupported Kind attribute for mma.m16n8k128"); + } else if (m == 16 && n == 8 && k == 64) { + if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 && + getScaleVecSize() == NVVM::ScaleVecSize::X1 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) + result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k64"); + } else + result = emitOpError("unsupported Geom for sparse mma with block scaling"); + return result; +} + LogicalResult ShflOp::verify() { auto returnStructType = llvm::dyn_cast(getType()); diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir new file mode 100644 index 0000000000000..fbd0203d19904 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir @@ -0,0 +1,525 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// This file contains tests for all dense MMA block scale operations in the NVVM dialect +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-with-block-scaling +// +// MMA block scale operations perform matrix multiply-accumulate with block scaling: +// D = matmul(A * SF_A, B * SF_B) + C +// where SF_A and SF_B are scaling factors with dimensions based on scale vector size. 
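+// As an illustrative sketch only (not additional test coverage), a fully
+// spelled-out m16n8k32 mxf8f6f4 invocation is expected to take the form
+// below; the attribute mnemonics and parameter spellings here are
+// assumptions derived from the op and enum definitions in NVVMOps.td:
+//
+//   %d = nvvm.mma.block_scale A[%a] B[%b] C[%c]
+//          scaleA[%scaleAData, %byteIdA, %threadIdA]
+//          scaleB[%scaleBData, %byteIdB, %threadIdB]
+//          {shape = #nvvm.shape<m = 16, n = 8, k = 32>,
+//           multiplicandAPtxType = #nvvm.mma_type<e4m3>,
+//           multiplicandBPtxType = #nvvm.mma_type<e4m3>,
+//           scaleVecSize = #nvvm.scale_vec_size<x1>,
+//           blockScaleFormat = #nvvm.block_scale_format<ue8m0>,
+//           kind = #nvvm.block_scale_kind<mxf8f6f4>}
+//          : (vector<4xi32>, vector<2xi32>, vector<4xf32>)
+//            -> !llvm.struct<(vector<4xf32>)>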
+ +// ============================================================================= +// MXF8F6F4 Block Scale MMA Operations (m16n8k32) - All Type Combinations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} 
+ : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: 
@nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: 
vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: 
nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] 
B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2 +func.func @nvvm_mxf8f6f4_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType 
= #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4 Block Scale MMA Operations (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4_blockscale_mma +func.func @nvvm_mxf4_blockscale_mma(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4NVF4 Block Scale MMA Operations (m16n8k64) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue8m0 +func.func @nvvm_mxf4nvf4_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf4nvf4_blockscale_mma_ue4m3 +func.func @nvvm_mxf4nvf4_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<2xi32>, %c: vector<4xf32>, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.block_scale A[%a] B[%b] C[%c] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind} + : (vector<4xi32>, vector<2xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} diff --git a/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir new file mode 100644 index 0000000000000..2e72012bcf722 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir @@ -0,0 +1,637 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// 
This file contains tests for all sparse MMA block scale operations in the NVVM dialect +// Based on PTX ISA documentation: +// https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-with-block-scaling +// +// Sparse MMA block scale operations perform matrix multiply-accumulate with block scaling +// on sparse matrices: D = matmul(A * SF_A, B * SF_B) + C +// where A follows 2:4 structured sparsity and SF_A, SF_B are scaling factors. + +// ============================================================================= +// MXF8F6F4 Sparse Block Scale MMA Operations (m16n8k64) - All Type Combinations +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + 
multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m1_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + 
%sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e2m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + 
scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: 
@nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e3m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] 
scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e4m3_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = 
#nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m1(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e2m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e3m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, 
%threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2 +func.func @nvvm_mxf8f6f4_sp_blockscale_mma_e5m2_e5m2(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4 Sparse Block Scale MMA Operations (m16n8k128) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4_sp_blockscale_mma +func.func @nvvm_mxf4_sp_blockscale_mma(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// ============================================================================= +// MXF4NVF4 Sparse Block Scale MMA Operations (m16n8k128) +// ============================================================================= + +// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0 +func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue8m0(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: 
nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} + +// CHECK-LABEL: @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3 +func.func @nvvm_mxf4nvf4_sp_blockscale_mma_ue4m3(%a: vector<4xi32>, %b: vector<4xi32>, %c: vector<4xf32>, + %sparseMetadata: i32, %sparsitySelector: i32, + %scaleAData: i32, %byteIdA: i16, %threadIdA: i16, + %scaleBData: i32, %byteIdB: i16, %threadIdB: i16) { + // CHECK: nvvm.mma.sp.block_scale A[{{.*}}] B[{{.*}}] C[{{.*}}] sparseMetadata[{{.*}}] selector[{{.*}}] scaleA[{{.*}}, {{.*}}, {{.*}}] scaleB[{{.*}}, {{.*}}, {{.*}}] + %0 = nvvm.mma.sp.block_scale A[%a] B[%b] C[%c] + sparseMetadata[%sparseMetadata] + selector[%sparsitySelector] + scaleA[%scaleAData, %byteIdA, %threadIdA] + scaleB[%scaleBData, %byteIdB, %threadIdB] + {shape = #nvvm.shape, + multiplicandAPtxType = #nvvm.mma_type, + multiplicandBPtxType = #nvvm.mma_type, + scaleVecSize = #nvvm.scale_vec_size, + blockScaleFormat = #nvvm.block_scale_format, + kind = #nvvm.block_scale_kind, + orderedMetadata} + : (vector<4xi32>, vector<4xi32>, vector<4xf32>) -> !llvm.struct<(vector<4xf32>)> + return +} From c664fd01cfa2bf91973c5d4fbdd6fc935f4bdd6f Mon Sep 17 00:00:00 2001 From: Kirill Vedernikov Date: Wed, 3 Dec 2025 23:30:30 +0100 Subject: [PATCH 2/3] [MLIR] Fixes for code formatting --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 323 +++++++++++---------- 1 file changed, 168 insertions(+), 155 deletions(-) diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a06fe19b8ad1c..2d40ecf6e5ee6 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -693,8 +693,8 @@ void MmaOp::print(OpAsmPrinter &p) { regTypes.push_back(this->getOperand(operandIdx).getType()); } } - std::optional inferredType = - MmaOp::inferOperandMMAType(regTypes.back(), /*isAccumulator=*/fragIdx >= 2); + std::optional inferredType = MmaOp::inferOperandMMAType( + regTypes.back(), /*isAccumulator=*/fragIdx >= 2); if (inferredType) ignoreAttrNames.push_back(frag.ptxTypeAttr); } @@ -1582,21 +1582,23 @@ void printOperandList(OpAsmPrinter &p, StringRef name, } // Helper to parse operand list in the format: name[operands] -LogicalResult parseMmaOperand(OpAsmParser &parser, StringRef operandName, - SmallVectorImpl ®s) { +LogicalResult +parseMmaOperand(OpAsmParser &parser, StringRef operandName, + SmallVectorImpl ®s) { if (parser.parseKeyword(operandName).failed()) return failure(); - if (parser.parseOperandList(regs, - OpAsmParser::Delimiter::OptionalSquare).failed()) + if (parser.parseOperandList(regs, OpAsmParser::Delimiter::OptionalSquare) + .failed()) return failure(); return success(); } -// Helper to process operand fragments and determine which attributes can be inferred +// Helper to process operand fragments and determine which attributes can be +// inferred template void processOperandFragments(Op &op, std::array &frags, - 
SmallVectorImpl ®Types, - SmallVectorImpl &ignoreAttrNames) { + SmallVectorImpl ®Types, + SmallVectorImpl &ignoreAttrNames) { for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) { auto &frag = frags[fragIdx]; auto varOperandSpec = op.getODSOperandIndexAndLength(fragIdx); @@ -1639,16 +1641,16 @@ LogicalResult parseMmaTypeSignature(OpAsmParser &parser, // Helper to infer and set multiplicand PTX type attributes void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, - const SmallVectorImpl &operandTypes) { + const SmallVectorImpl &operandTypes) { if (!attrs.get("multiplicandAPtxType")) { - if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[0], - false)) { + if (auto inferredType = + MmaOp::inferOperandMMAType(operandTypes[0], false)) { attrs.set("multiplicandAPtxType", MMATypesAttr::get(ctx, *inferredType)); } } if (!attrs.get("multiplicandBPtxType")) { - if (auto inferredType = MmaOp::inferOperandMMAType(operandTypes[1], - false)) { + if (auto inferredType = + MmaOp::inferOperandMMAType(operandTypes[1], false)) { attrs.set("multiplicandBPtxType", MMATypesAttr::get(ctx, *inferredType)); } } @@ -1656,37 +1658,34 @@ void inferAndSetMultiplicandTypes(MLIRContext *ctx, NamedAttrList &attrs, // Helper to add common block scale attributes void addBlockScaleAttributes(OpBuilder &builder, OperationState &result, - ArrayRef shape, - ScaleVecSize scaleVecSize, + ArrayRef shape, ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, MMABlockScaleKind kind) { MLIRContext *ctx = builder.getContext(); - result.addAttribute("shape", - builder.getAttr(shape[0], shape[1], - shape[2])); - result.addAttribute("scaleVecSize", - ScaleVecSizeAttr::get(ctx, scaleVecSize)); + result.addAttribute( + "shape", builder.getAttr(shape[0], shape[1], shape[2])); + result.addAttribute( + "scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize)); result.addAttribute("blockScaleFormat", BlockScaleFormatAttr::get(ctx, blockScaleFormat)); result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind)); } // Helper to infer and add multiplicand PTX types to builder -void addInferredMultiplicandTypes(MLIRContext *ctx, OperationState &result, - ValueRange operandA, ValueRange operandB, - std::optional> multiplicandPtxTypes) { +void addInferredMultiplicandTypes( + MLIRContext *ctx, OperationState &result, ValueRange operandA, + ValueRange operandB, + std::optional> multiplicandPtxTypes) { if (multiplicandPtxTypes) { result.addAttribute("multiplicandAPtxType", - MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0])); result.addAttribute("multiplicandBPtxType", - MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); + MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1])); } else { if (auto res = MmaOp::inferOperandMMAType(operandA[0].getType(), false)) - result.addAttribute("multiplicandAPtxType", - MMATypesAttr::get(ctx, *res)); + result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res)); if (auto res = MmaOp::inferOperandMMAType(operandB[0].getType(), false)) - result.addAttribute("multiplicandBPtxType", - MMATypesAttr::get(ctx, *res)); + result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res)); } } @@ -1730,7 +1729,8 @@ void MmaBlockScaleOp::print(OpAsmPrinter &p) { p << " : ("; llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), frags[1].regs[0].getType(), - frags[2].regs[0].getType()}, p); + frags[2].regs[0].getType()}, + p); p << ")"; 
p.printArrowTypeList(TypeRange{this->getRes().getType()}); } @@ -1775,52 +1775,55 @@ ParseResult MmaBlockScaleOp::parse(OpAsmParser &parser, for (const auto &[idx, frag] : llvm::enumerate(frags)) { frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx], /*isAccumulator=*/idx >= 2); - if (parser.resolveOperands(frag.regs, operandTypes[idx], - parser.getNameLoc(), result.operands).failed()) + if (parser + .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(), + result.operands) + .failed()) return failure(); } // Resolve scale operands - SmallVector scaleTypes = { - builder.getI32Type(), builder.getI16Type(), builder.getI16Type() - }; - if (parser.resolveOperands(scaleAOperands, scaleTypes, - parser.getNameLoc(), result.operands).failed() || - parser.resolveOperands(scaleBOperands, scaleTypes, - parser.getNameLoc(), result.operands).failed()) + SmallVector scaleTypes = {builder.getI32Type(), builder.getI16Type(), + builder.getI16Type()}; + if (parser + .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed() || + parser + .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed()) return failure(); // Add attributes result.addAttributes(namedAttributes); - inferAndSetMultiplicandTypes(parser.getContext(), - result.attributes, operandTypes); + inferAndSetMultiplicandTypes(parser.getContext(), result.attributes, + operandTypes); result.addTypes(resultTypes); - result.addAttribute( - MmaBlockScaleOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(frags[0].regs.size()), - static_cast(frags[1].regs.size()), - static_cast(frags[2].regs.size()), - 1, // scaleAData - 1, // byteIdA - 1, // threadIdA - 1, // scaleBData - 1, // byteIdB - 1 // threadIdB - })); + result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); return success(); } -void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result, - Type resultType, ValueRange operandA, - ValueRange operandB, ValueRange operandC, - Value scaleAData, Value byteIdA, Value threadIdA, - Value scaleBData, Value byteIdB, Value threadIdB, - ArrayRef shape, - std::optional> multiplicandPtxTypes, - ScaleVecSize scaleVecSize, - BlockScaleFormat blockScaleFormat, - MMABlockScaleKind kind) { +void MmaBlockScaleOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value scaleAData, Value byteIdA, Value threadIdA, Value scaleBData, + Value byteIdB, Value threadIdB, ArrayRef shape, + std::optional> multiplicandPtxTypes, + ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); addBlockScaleAttributes(builder, result, shape, scaleVecSize, @@ -1829,25 +1832,25 @@ void MmaBlockScaleOp::build(OpBuilder &builder, OperationState &result, result.addOperands(operandA); result.addOperands(operandB); result.addOperands(operandC); - result.addOperands({scaleAData, byteIdA, threadIdA, - scaleBData, byteIdB, threadIdB}); + result.addOperands( + {scaleAData, byteIdA, threadIdA, scaleBData, byteIdB, threadIdB}); - addInferredMultiplicandTypes(builder.getContext(), result, - operandA, operandB, 
multiplicandPtxTypes); + addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB, + multiplicandPtxTypes); result.addTypes(resultType); - result.addAttribute( - MmaBlockScaleOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(operandA.size()), - static_cast(operandB.size()), - static_cast(operandC.size()), - 1, // scaleAData - 1, // byteIdA - 1, // threadIdA - 1, // scaleBData - 1, // byteIdB - 1 // threadIdB - })); + result.addAttribute(MmaBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); } MMATypes MmaBlockScaleOp::accumPtxType() { @@ -1882,8 +1885,8 @@ NVVM::IDArgPair MmaBlockScaleOp::getIntrinsicIDAndArgs( unsigned intId = MmaBlockScaleOp::getIntrinsicID( curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(), *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(), - curOp.accumPtxType(), - curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind()); + curOp.accumPtxType(), curOp.getScaleVecSize(), + curOp.getBlockScaleFormat(), curOp.getKind()); return {intId, args}; } @@ -1908,9 +1911,9 @@ LogicalResult MmaBlockScaleOp::verify() { "unsupported BlockScaleFormat attribute for mma.m16n8k64.mxf4"); } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) { if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 && - getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || - (getScaleVecSize() == NVVM::ScaleVecSize::X4 && - getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || + (getScaleVecSize() == NVVM::ScaleVecSize::X4 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat " "attributes for mma.m16n8k64.mxf4nvf4"); } else @@ -1919,8 +1922,9 @@ LogicalResult MmaBlockScaleOp::verify() { if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 && getScaleVecSize() == NVVM::ScaleVecSize::X1 && getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) - result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " - "attributes for mma.m16n8k32"); + result = + emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k32"); } else result = emitOpError("unsupported Geom for mma with block scaling"); return result; @@ -1961,7 +1965,8 @@ void MmaSpBlockScaleOp::print(OpAsmPrinter &p) { p << " : ("; llvm::interleaveComma(SmallVector{frags[0].regs[0].getType(), frags[1].regs[0].getType(), - frags[2].regs[0].getType()}, p); + frags[2].regs[0].getType()}, + p); p << ")"; p.printArrowTypeList(TypeRange{this->getRes().getType()}); } @@ -1984,7 +1989,8 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser, return failure(); // Parse sparse-specific operands - SmallVector metadataOperands, selectorOperands; + SmallVector metadataOperands, + selectorOperands; if (parseMmaOperand(parser, "sparseMetadata", metadataOperands).failed() || parseMmaOperand(parser, "selector", selectorOperands).failed()) return failure(); @@ -2012,67 +2018,74 @@ ParseResult MmaSpBlockScaleOp::parse(OpAsmParser &parser, for (const auto &[idx, frag] : llvm::enumerate(frags)) { frag.elemtype = MmaOp::inferOperandMMAType(operandTypes[idx], /*isAccumulator=*/idx >= 2); - if (parser.resolveOperands(frag.regs, operandTypes[idx], 
- parser.getNameLoc(), result.operands).failed()) + if (parser + .resolveOperands(frag.regs, operandTypes[idx], parser.getNameLoc(), + result.operands) + .failed()) return failure(); } // Resolve sparse metadata and selector Type i32Type = builder.getI32Type(); - if (parser.resolveOperands(metadataOperands, i32Type, - parser.getNameLoc(), result.operands).failed() || - parser.resolveOperands(selectorOperands, i32Type, - parser.getNameLoc(), result.operands).failed()) + if (parser + .resolveOperands(metadataOperands, i32Type, parser.getNameLoc(), + result.operands) + .failed() || + parser + .resolveOperands(selectorOperands, i32Type, parser.getNameLoc(), + result.operands) + .failed()) return failure(); // Resolve scale operands - SmallVector scaleTypes = { - i32Type, builder.getI16Type(), builder.getI16Type() - }; - if (parser.resolveOperands(scaleAOperands, scaleTypes, - parser.getNameLoc(), result.operands).failed() || - parser.resolveOperands(scaleBOperands, scaleTypes, - parser.getNameLoc(), result.operands).failed()) + SmallVector scaleTypes = {i32Type, builder.getI16Type(), + builder.getI16Type()}; + if (parser + .resolveOperands(scaleAOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed() || + parser + .resolveOperands(scaleBOperands, scaleTypes, parser.getNameLoc(), + result.operands) + .failed()) return failure(); // Add attributes result.addAttributes(namedAttributes); - inferAndSetMultiplicandTypes(parser.getContext(), - result.attributes, operandTypes); + inferAndSetMultiplicandTypes(parser.getContext(), result.attributes, + operandTypes); // orderedMetadata is mandatory if (!result.attributes.get("orderedMetadata")) result.addAttribute("orderedMetadata", builder.getUnitAttr()); result.addTypes(resultTypes); - result.addAttribute( - MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(frags[0].regs.size()), - static_cast(frags[1].regs.size()), - static_cast(frags[2].regs.size()), - 1, // sparseMetadata - 1, // sparsitySelector - 1, // scaleAData - 1, // byteIdA - 1, // threadIdA - 1, // scaleBData - 1, // byteIdB - 1 // threadIdB - })); + result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(frags[0].regs.size()), + static_cast(frags[1].regs.size()), + static_cast(frags[2].regs.size()), + 1, // sparseMetadata + 1, // sparsitySelector + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); return success(); } -void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result, - Type resultType, ValueRange operandA, - ValueRange operandB, ValueRange operandC, - Value sparseMetadata, Value sparsitySelector, - Value scaleAData, Value byteIdA, Value threadIdA, - Value scaleBData, Value byteIdB, Value threadIdB, - ArrayRef shape, - std::optional> multiplicandPtxTypes, - ScaleVecSize scaleVecSize, - BlockScaleFormat blockScaleFormat, - MMABlockScaleKind kind) { +void MmaSpBlockScaleOp::build( + OpBuilder &builder, OperationState &result, Type resultType, + ValueRange operandA, ValueRange operandB, ValueRange operandC, + Value sparseMetadata, Value sparsitySelector, Value scaleAData, + Value byteIdA, Value threadIdA, Value scaleBData, Value byteIdB, + Value threadIdB, ArrayRef shape, + std::optional> multiplicandPtxTypes, + ScaleVecSize scaleVecSize, BlockScaleFormat blockScaleFormat, + MMABlockScaleKind kind) { assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)"); 
addBlockScaleAttributes(builder, result, shape, scaleVecSize, @@ -2082,28 +2095,27 @@ void MmaSpBlockScaleOp::build(OpBuilder &builder, OperationState &result, result.addOperands(operandA); result.addOperands(operandB); result.addOperands(operandC); - result.addOperands({sparseMetadata, sparsitySelector, - scaleAData, byteIdA, threadIdA, - scaleBData, byteIdB, threadIdB}); + result.addOperands({sparseMetadata, sparsitySelector, scaleAData, byteIdA, + threadIdA, scaleBData, byteIdB, threadIdB}); - addInferredMultiplicandTypes(builder.getContext(), result, - operandA, operandB, multiplicandPtxTypes); + addInferredMultiplicandTypes(builder.getContext(), result, operandA, operandB, + multiplicandPtxTypes); result.addTypes(resultType); - result.addAttribute( - MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), - builder.getDenseI32ArrayAttr({static_cast(operandA.size()), - static_cast(operandB.size()), - static_cast(operandC.size()), - 1, // sparseMetadata - 1, // sparsitySelector - 1, // scaleAData - 1, // byteIdA - 1, // threadIdA - 1, // scaleBData - 1, // byteIdB - 1 // threadIdB - })); + result.addAttribute(MmaSpBlockScaleOp::getOperandSegmentSizeAttr(), + builder.getDenseI32ArrayAttr({ + static_cast(operandA.size()), + static_cast(operandB.size()), + static_cast(operandC.size()), + 1, // sparseMetadata + 1, // sparsitySelector + 1, // scaleAData + 1, // byteIdA + 1, // threadIdA + 1, // scaleBData + 1, // byteIdB + 1 // threadIdB + })); } MMATypes MmaSpBlockScaleOp::accumPtxType() { @@ -2142,8 +2154,8 @@ NVVM::IDArgPair MmaSpBlockScaleOp::getIntrinsicIDAndArgs( unsigned intId = MmaSpBlockScaleOp::getIntrinsicID( curOp.getShape().getM(), curOp.getShape().getN(), curOp.getShape().getK(), *curOp.getMultiplicandAPtxType(), *curOp.getMultiplicandBPtxType(), - curOp.accumPtxType(), - curOp.getScaleVecSize(), curOp.getBlockScaleFormat(), curOp.getKind()); + curOp.accumPtxType(), curOp.getScaleVecSize(), + curOp.getBlockScaleFormat(), curOp.getKind()); return {intId, args}; } @@ -2173,9 +2185,9 @@ LogicalResult MmaSpBlockScaleOp::verify() { "unsupported BlockScaleFormat attribute for mma.m16n8k128.mxf4"); } else if (getKind() == NVVM::MMABlockScaleKind::MXF4NVF4) { if (!((getScaleVecSize() == NVVM::ScaleVecSize::X2 && - getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || - (getScaleVecSize() == NVVM::ScaleVecSize::X4 && - getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0) || + (getScaleVecSize() == NVVM::ScaleVecSize::X4 && + getBlockScaleFormat() == NVVM::BlockScaleFormat::UE4M3))) result = emitOpError("unsupported ScaleVecSize and BlockScaleFormat " "attributes for mma.m16n8k128.mxf4nvf4"); } else @@ -2184,8 +2196,9 @@ LogicalResult MmaSpBlockScaleOp::verify() { if (!(getKind() == NVVM::MMABlockScaleKind::MXF8F6F4 && getScaleVecSize() == NVVM::ScaleVecSize::X1 && getBlockScaleFormat() == NVVM::BlockScaleFormat::UE8M0)) - result = emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " - "attributes for mma.m16n8k64"); + result = + emitOpError("unsupported Kind, ScaleVecSize and BlockScaleFormat " + "attributes for mma.m16n8k64"); } else result = emitOpError("unsupported Geom for sparse mma with block scaling"); return result; From 765647ab594549bd2678994c0a20fc7692e3432d Mon Sep 17 00:00:00 2001 From: Kirill Vedernikov Date: Wed, 3 Dec 2025 23:31:42 +0100 Subject: [PATCH 3/3] [MLIR] One more fix for code formatting --- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 2d40ecf6e5ee6..3387c69025e20 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1664,8 +1664,7 @@ void addBlockScaleAttributes(OpBuilder &builder, OperationState &result, MLIRContext *ctx = builder.getContext(); result.addAttribute( "shape", builder.getAttr(shape[0], shape[1], shape[2])); - result.addAttribute( - "scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize)); + result.addAttribute("scaleVecSize", ScaleVecSizeAttr::get(ctx, scaleVecSize)); result.addAttribute("blockScaleFormat", BlockScaleFormatAttr::get(ctx, blockScaleFormat)); result.addAttribute("kind", MMABlockScaleKindAttr::get(ctx, kind));
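For illustration, here is a minimal sketch of driving the new dense op from C++ through the MmaBlockScaleOp::build overload introduced in this series. Only the build parameter list, the NVVM enum names, and the m16n8k32 / mxf8f6f4 / x1 / ue8m0 combination accepted by the verifier come from the patch itself; the enclosing builder, loc, and the SSA values a, b, c plus the scale/byte/thread ids are assumptions, expected to already exist with the register types used in the tests (vector<4xi32>, vector<2xi32>, vector<4xf32>, and i32/i16/i16 for the scale operands).

  // Sketch under the assumptions stated above: 'builder', 'loc', and the
  // operand Values are in scope, with the usual MLIR/NVVM headers included.
  MLIRContext *ctx = builder.getContext();
  Type f32x4 = VectorType::get({4}, builder.getF32Type());
  // Result type matching the tests: !llvm.struct<(vector<4xf32>)>.
  Type resTy = LLVM::LLVMStructType::getLiteral(ctx, {f32x4});
  SmallVector<int64_t, 3> shape{16, 8, 32};

  // D = (A * SF_A) x (B * SF_B) + C with e4m3 inputs and ue8m0 scale factors,
  // i.e. the mxf8f6f4 kind with the x1 scale vector size.
  auto mma = builder.create<NVVM::MmaBlockScaleOp>(
      loc, resTy,
      /*operandA=*/ValueRange{a}, /*operandB=*/ValueRange{b},
      /*operandC=*/ValueRange{c},
      scaleAData, byteIdA, threadIdA,   // scale data, byte id, thread id for A
      scaleBData, byteIdB, threadIdB,   // scale data, byte id, thread id for B
      shape,
      std::array<NVVM::MMATypes, 2>{NVVM::MMATypes::e4m3, NVVM::MMATypes::e4m3},
      NVVM::ScaleVecSize::X1, NVVM::BlockScaleFormat::UE8M0,
      NVVM::MMABlockScaleKind::MXF8F6F4);

The sparse variant follows the same pattern, with the additional sparseMetadata and sparsitySelector values placed before the scale operands, as in the MmaSpBlockScaleOp::build signature above.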