@@ -277,6 +277,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
277277 !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
278278 !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
279279
280+ // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
281+ !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
282+ !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
283+
280284 // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
281285 // All other supported geometries use the same fragment format for f32 and
282286 // f16, so we only need to consider {fragment, type}.
@@ -520,6 +524,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
520524 # signature;
521525}
522526
527+ class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
528+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
529+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
530+ string record_name = "int_nvvm_mma_block_scale"
531+ # "_" # A.geom
532+ # "_row_col"
533+ # "_" # Kind
534+ # !subst(".", "_", ScaleVecSize)
535+ # signature
536+ # "_" # SType;
537+ }
538+
523539class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
524540 WMMA_REGS A, WMMA_REGS B,
525541 WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +549,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
533549 # signature;
534550}
535551
552+ class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
553+ WMMA_REGS A, WMMA_REGS B,
554+ WMMA_REGS C, WMMA_REGS D> {
555+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
556+ string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
557+ # "_" # A.geom
558+ # "_row_col"
559+ # "_" # Kind
560+ # !subst(".", "_", ScaleVecSize)
561+ # signature
562+ # "_" # SType;
563+ }
564+
536565// Helper class that takes an intrinsic name and construct a record name.
537566// Additionally, sets `intr_name` to be non-empty if the default name assigned
538567// to this intrinsic will not match the name given.
@@ -683,6 +712,18 @@ class NVVM_MMA_OPS {
683712 fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
684713 int_mma_ops, subint_mma_ops, bit_mma_ops);
685714
715+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
716+ ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
717+ >.ret;
718+
719+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
720+ ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
721+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
722+ >.ret;
723+
724+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
725+ mxf4_mma_ops, mxf8f6f4_mma_ops);
726+
686727 list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
687728 ["m16n8k16", "m16n8k32"],
688729 ["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +748,18 @@ class NVVM_MMA_OPS {
707748 bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
708749 subint_mma_sp_ops, int_mma_sp_ops);
709750
751+ // combines available geoms and types for mxf4 and mxf4nvf4 kinds
752+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
753+ ["m16n8k128"],
754+ ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
755+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
756+ ["m16n8k64"],
757+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
758+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
759+ ["f32"], [], true>.ret;
760+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
761+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
762+
710763 list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
711764 ["m16n16k16", "m32n8k16", "m8n32k16"],
712765 ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +953,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
900953 );
901954}
902955
956+ class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
957+ string geom = frags[0].geom;
958+
959+ bit ret = !cond(
960+ !and(!eq(geom, "m16n8k64"),
961+ !eq(kind, "mxf4"),
962+ !or(!eq(scale_vec_size, ""),
963+ !eq(scale_vec_size, ".scale_2x")),
964+ !eq(stype, "ue8m0")) : true,
965+ !and(!eq(geom, "m16n8k64"),
966+ !eq(kind, "mxf4nvf4"),
967+ !eq(scale_vec_size, ".scale_2x"),
968+ !eq(stype, "ue8m0")) : true,
969+ !and(!eq(geom, "m16n8k64"),
970+ !eq(kind, "mxf4nvf4"),
971+ !eq(scale_vec_size, ".scale_4x"),
972+ !eq(stype, "ue4m3")) : true,
973+ !and(!eq(geom, "m16n8k32"),
974+ !eq(kind, "mxf8f6f4"),
975+ !or(!eq(scale_vec_size, ""),
976+ !eq(scale_vec_size, ".scale_1x")),
977+ !eq(stype, "ue8m0")) : true,
978+ true: false
979+ );
980+ }
981+
903982// Returns true if the fragment is valid for ldmatrix ops is supported;
904983// false otherwise.
905984// E.g.
@@ -998,6 +1077,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
9981077}
9991078
10001079
1080+ // Returns true if this combination of kind/scale_vec_size/stype
1081+ // for MMA.SP ops is supported;
1082+ // false otherwise.
1083+ // E.g.
1084+ // if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
1085+ // def : FOO<>; // The record will only be defined for supported ops.
1086+ //
1087+ class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
1088+ string stype, string scale_vec_size> {
1089+ // MMA.SP ops check both layouts.
1090+ string a_type = frags[0].ptx_elt_type;
1091+ string b_type = frags[1].ptx_elt_type;
1092+ string c_type = frags[2].ptx_elt_type;
1093+ string d_type = frags[3].ptx_elt_type;
1094+ string geom = frags[0].geom;
1095+
1096+ bit ret = !cond(
1097+ !and(!eq(geom, "m16n8k128"),
1098+ !eq(kind, "mxf4"),
1099+ !eq(stype, "ue8m0"),
1100+ !or(!eq(scale_vec_size, ""),
1101+ !eq(scale_vec_size, ".scale_2x"))): true,
1102+
1103+ !and(!eq(geom, "m16n8k128"),
1104+ !eq(kind, "mxf4nvf4"),
1105+ !eq(stype, "ue8m0"),
1106+ !eq(scale_vec_size, ".scale_2x")): true,
1107+
1108+ !and(!eq(geom, "m16n8k128"),
1109+ !eq(kind, "mxf4nvf4"),
1110+ !eq(stype, "ue4m3"),
1111+ !eq(scale_vec_size, ".scale_4x")): true,
1112+
1113+ !and(!eq(geom, "m16n8k64"),
1114+ !eq(kind, "mxf8f6f4"),
1115+ !eq(stype, "ue8m0"),
1116+ !or(!eq(scale_vec_size, ""),
1117+ !eq(scale_vec_size, ".scale_1x"))): true,
1118+
1119+ // All other are NOT OK.
1120+ true: false
1121+ );
1122+ }
1123+
1124+
10011125class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
10021126 string Suffix = !if(sync, "sync_", "")
10031127 # mode # "_"
@@ -2415,6 +2539,31 @@ foreach layout_a = ["row", "col"] in {
24152539 } // layout_b
24162540} // layout_a
24172541
2542+ class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2543+ : Intrinsic<D.regs,
2544+ !listconcat(A.regs, B.regs, C.regs,
2545+ [
2546+ llvm_i32_ty, // scale-a-data
2547+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2548+ llvm_i32_ty, // scale-b-data,
2549+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2550+ ]),
2551+ [IntrNoMem, IntrNoCallback]>;
2552+
2553+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2554+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2555+ foreach stype = ["ue8m0", "ue4m3"] in {
2556+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
2557+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2558+ def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2559+ op[0], op[1], op[2], op[3]>.record_name
2560+ : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2561+ }
2562+ } // op
2563+ } // stype
2564+ } // scale_vec_size
2565+ } // kind
2566+
24182567// MMA.SP
24192568class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
24202569 : Intrinsic<D.regs,
@@ -2462,6 +2611,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
24622611 } // kind
24632612} // metadata
24642613
2614+ // MMA.SP BLOCK SCALE
2615+ class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2616+ : Intrinsic<D.regs,
2617+ !listconcat(A.regs, B.regs, C.regs,
2618+ [
2619+ llvm_i32_ty, // metadata
2620+ llvm_i32_ty, // sparsity selector
2621+ llvm_i32_ty, // scale-a-data
2622+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2623+ llvm_i32_ty, // scale-b-data
2624+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2625+ ])> {
2626+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
2627+
2628+ // The range [0;num_threads) is for the sparsity selector that indicates the threads
2629+ // which contribute metadata.
2630+ // According to PTX ISA 9.0, the sparsity selector is always 0
2631+ // for sparse MMA block scale instructions
2632+ int num_threads = 1;
2633+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
2634+ Range<ArgIndex<pos>, 0, num_threads>];
2635+ }
2636+
2637+ // According to PTX ISA 9.0
2638+ // a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
2639+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2640+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2641+ foreach stype = ["ue8m0", "ue4m3"] in {
2642+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
2643+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2644+ def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2645+ op[0], op[1], op[2], op[3]>.record_name
2646+ : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2647+ }
2648+ } // op
2649+ } // stype
2650+ } // scale_vec_size
2651+ } // kind
2652+
24652653// LDMATRIX
24662654class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
24672655 : Intrinsic<Frag.regs, [llvm_anyptr_ty],
0 commit comments