[NVPTX] Support for dense and sparse MMA intrinsics with block scaling. #163561
Conversation
@llvm/pr-subscribers-backend-nvptx @llvm/pr-subscribers-llvm-ir

Author: Kirill Vedernikov (kvederni)

Changes

This change adds dense and sparse MMA intrinsics with block scaling. The implementation is based on PTX ISA version 9.0. Tests for the new intrinsics are added for PTX 8.7 and SM 120a and are generated by llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py.

Patch is 29.23 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/163561.diff

3 Files Affected:
- llvm/include/llvm/IR/IntrinsicsNVVM.td
- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
- llvm/test/CodeGen/NVPTX/wmma.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 3af1750ffcf3f..6256baa50a1c6 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -277,6 +277,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
+ // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
+ !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
+
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
// f16, so we only need to consider {fragment, type}.
@@ -520,6 +524,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
# signature;
}
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record = "int_nvvm_mma_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +549,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
# signature;
}
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record = "int_nvvm_mma_sp_ordered_metadata_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
string intr = "llvm.nvvm.ldmatrix.sync.aligned"
# "." # Frag.geom
@@ -672,6 +701,18 @@ class NVVM_MMA_OPS {
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
int_mma_ops, subint_mma_ops, bit_mma_ops);
+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+ ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
+ >.ret;
+
+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+ ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
+ >.ret;
+
+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+ mxf4_mma_ops, mxf8f6f4_mma_ops);
+
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
["m16n8k16", "m16n8k32"],
["bf16"], [], ["f32"], [], true>.ret;
@@ -696,6 +737,18 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);
+ // combines available geoms and types for mxf4 and mxf4nvf4 kinds
+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+ ["m16n8k128"],
+ ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -889,6 +942,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
);
}
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x")),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_2x"),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_4x"),
+ !eq(stype, "ue4m3")) : true,
+ !and(!eq(geom, "m16n8k32"),
+ !eq(kind, "mxf8f6f4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x")),
+ !eq(stype, "ue8m0")) : true,
+ true: false
+ );
+}
+
// Returns true if the fragment is valid for ldmatrix ops is supported;
// false otherwise.
// E.g.
@@ -987,6 +1066,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
}
+// Returns true if this combination of kind/scale_vec_size/stype
+// for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ // MMA.SP ops check both layouts.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x"))): true,
+
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale_vec_size, ".scale_2x")): true,
+
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue4m3"),
+ !eq(scale_vec_size, ".scale_4x")): true,
+
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf8f6f4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x"))): true,
+
+ // All other are NOT OK.
+ true: false
+ );
+}
+
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -2340,6 +2464,31 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs,
+ [
+ llvm_i32_ty, // scale-a-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+ llvm_i32_ty, // scale-b-data,
+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+ ]),
+ [IntrNoMem, IntrNoCallback]>;
+
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+} // kind
+
// MMA.SP
class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
@@ -2387,6 +2536,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
} // kind
} // metadata
+// MMA.SP BLOCK SCALE
+class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs,
+ [
+ llvm_i32_ty, // metadata
+ llvm_i32_ty, // sparsity selector
+ llvm_i32_ty, // scale-a-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+ llvm_i32_ty, // scale-b-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+ ])> {
+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+
+ // The range [0;num_threads) is for the sparsity selector that indicates the threads
+ // which contribute metadata.
+ // According to PTX ISA 9.0, the sparsity selector is always 0
+ // for sparse MMA block scale instructions
+ int num_threads = 1;
+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+ Range<ArgIndex<pos>, 0, num_threads>];
+}
+
+// According to PTX ISA 9.0
+// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+} // kind
+
// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
@@ -2984,4 +3172,4 @@ foreach sp = [0, 1] in {
}
}
-} // let TargetPrefix = "nvvm"
\ No newline at end of file
+} // let TargetPrefix = "nvvm"
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 22cf3a7eef2c1..febb9b41c84f3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4886,6 +4886,67 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$scale_a, B16:$byte_id_a,
+ B16:$thread_id_a, B32:$scale_b,
+ B16:$byte_id_b, B16:$thread_id_b)]>,
+ // Requires does not seem to have effect on Instruction w/o Patterns.
+ // We set it here anyways and propagate to the Pat<> we construct below.
+ Requires<FragA.Predicates> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = !interleave([FragD.ptx_elt_type,
+ FragA.ptx_elt_type,
+ FragB.ptx_elt_type,
+ FragC.ptx_elt_type], ".");
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # "." # TypeList
+ # "." # SType # " \n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
// MMA SP
class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -4942,6 +5003,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
} // defset
}
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector,
+ B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+ B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+ // Requires does not seem to have effect on Instruction w/o Patterns.
+ // We set it here anyways and propagate to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # TypeList
+ # "." # SType # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5023,7 +5150,8 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+ STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8427ae4ad72da..81c78219075f3 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -131,7 +131,7 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
"m16n8k64:b:e5m2": 4,
"m16n8k64:b:e3m2": 4,
"m16n8k64:b:e2m3": 4,
- "m16n8k64:b:e2m1": 4,
+ "m16n8k64:b:e2m1": 4 if is_mma_sparse else 2,
"m16n8k64:c:f16": 2,
"m16n8k64:c:f32": 4,
"m16n8k64:d:f16": 2,
@@ -1131,6 +1131,163 @@ def gen_mma_tests():
return generated_items
+def get_mma_block_scale_ops():
+ return (
+ make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], [])
+ + make_mma_ops(
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
+ )
+ )
+
+
+def is_mma_block_scale_geom_supported(geom):
+ # geometries for FP.
+ if geom in [
+ "m16n8k32",
+ "m16n8k64",
+ ]:
+ return True
+ raise ValueError(f"Unexpected MMA block scale geometry: {geom}")
+
+
+def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type)
+ and is_mma_block_scale_geom_supported(op.a.geom)
+ ):
+ return False
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::2X"]
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4nvf4"
+ and stype == "ue8m0"
+ and scale_vec_size == ".scale_vec::2X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4nvf4"
+ and stype == "ue4m3"
+ and scale_vec_size == ".scale_vec::4X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k32"
+ and kind == "mxf8f6f4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::1X"]
+ ):
+ return True
+
+ return False
+
+
+def common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+ mma_block_scale_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_...
[truncated]
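For readers skimming the patch: the dense and sparse support matrices encoded by NVVM_MMA_BLOCK_SCALE_SUPPORTED and NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED above can be restated compactly. The sketch below is a plain-Python paraphrase of those !cond tables, not code from the patch, and the function names are ours.

# Paraphrase of the !cond tables in NVVM_MMA_BLOCK_SCALE_SUPPORTED and
# NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED; function names are illustrative.

def dense_block_scale_supported(geom, kind, stype, scale_vec_size):
    if geom == "m16n8k64" and kind == "mxf4":
        return stype == "ue8m0" and scale_vec_size in ("", ".scale_2x")
    if geom == "m16n8k64" and kind == "mxf4nvf4":
        return (stype, scale_vec_size) in (
            ("ue8m0", ".scale_2x"), ("ue4m3", ".scale_4x"))
    if geom == "m16n8k32" and kind == "mxf8f6f4":
        return stype == "ue8m0" and scale_vec_size in ("", ".scale_1x")
    return False

def sparse_block_scale_supported(geom, kind, stype, scale_vec_size):
    # The sparse geometries double K relative to their dense counterparts.
    if geom == "m16n8k128" and kind == "mxf4":
        return stype == "ue8m0" and scale_vec_size in ("", ".scale_2x")
    if geom == "m16n8k128" and kind == "mxf4nvf4":
        return (stype, scale_vec_size) in (
            ("ue8m0", ".scale_2x"), ("ue4m3", ".scale_4x"))
    if geom == "m16n8k64" and kind == "mxf8f6f4":
        return stype == "ue8m0" and scale_vec_size in ("", ".scale_1x")
    return False

An empty scale_vec_size corresponds to omitting the .scale_vec qualifier in the generated PTX, per the ScaleVecSizeStr mapping in the patch.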
✅ With the latest revision, this PR passed the Python code formatter.
Please also make sure that the ptxas tests still work with the ptxas from CUDA 12.8.
This change adds dense and sparse MMA intrinsics with block scaling. The implementation is based on PTX ISA version 9.0. Tests for the new intrinsics are added for PTX 8.7 and SM 120a and are generated by llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py. The tests have been verified with ptxas from the CUDA 13.0 release. Dense MMA intrinsics with block scaling were contributed by @schwarzschild-radius.
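As a usage note: the AsmString in the MMA_BLOCK_SCALE class assembles PTX mnemonics of the following shape. The sketch below mirrors the ScaleVecSizeStr mapping and string concatenation from the patch; the helper name is ours, for illustration only.

# Mirrors the ScaleVecSizeStr mapping and AsmString concatenation in
# MMA_BLOCK_SCALE; the helper name is illustrative, not part of the patch.
SCALE_VEC = {
    "": "",
    ".scale_1x": ".scale_vec::1X",
    ".scale_2x": ".scale_vec::2X",
    ".scale_4x": ".scale_vec::4X",
}

def dense_mnemonic(geom, kind, scale_vec_size, d, a, b, c, stype):
    type_list = ".".join([d, a, b, c])  # D.A.B.C order, as in TypeList
    return (f"mma.sync.aligned.{geom}.row.col.kind::{kind}"
            f".block_scale{SCALE_VEC[scale_vec_size]}.{type_list}.{stype}")

print(dense_mnemonic("m16n8k32", "mxf8f6f4", ".scale_1x",
                     "f32", "e4m3", "e4m3", "f32", "ue8m0"))
# mma.sync.aligned.m16n8k32.row.col.kind::mxf8f6f4.block_scale.scale_vec::1X.f32.e4m3.e4m3.f32.ue8m0

The sparse variant uses the mma.sp::ordered_metadata.sync.aligned. prefix instead and, per the NVVM_MMA_SP_BLOCK_SCALE definition in the patch, takes the metadata word and a sparsity selector (which must be 0 under PTX ISA 9.0) ahead of the scale-a and scale-b operand groups.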