Skip to content

Commit 961cc0f

Browse files
committed
[NVPTX] Support for dense and sparse MMA intrinsics with block scaling.
1 parent 3f61402 commit 961cc0f

File tree

3 files changed

+657
-3
lines changed

3 files changed

+657
-3
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
277277
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
278278
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
279279

280+
// mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
281+
!eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
282+
!eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
283+
280284
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
281285
// All other supported geometries use the same fragment format for f32 and
282286
// f16, so we only need to consider {fragment, type}.
@@ -520,6 +524,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
520524
# signature;
521525
}
522526

527+
// Builds the record name of a dense block-scaled MMA intrinsic:
//   "int_nvvm_mma_block_scale" "_" <A.geom> "_row_col" "_" <Kind>
//   <ScaleVecSize> <signature> "_" <SType>
// Every "." in ScaleVecSize is rewritten to "_", and an empty ScaleVecSize
// adds nothing to the name. The "row"/"col" layout pair is fixed in the name.
// `signature` is whatever MMA_SIGNATURE produces for the four fragments.
class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
528+
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
529+
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
530+
string record_name = "int_nvvm_mma_block_scale"
531+
# "_" # A.geom
532+
# "_row_col"
533+
# "_" # Kind
534+
# !subst(".", "_", ScaleVecSize)
535+
# signature
536+
# "_" # SType;
537+
}
538+
523539
class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
524540
WMMA_REGS A, WMMA_REGS B,
525541
WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +549,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
533549
# signature;
534550
}
535551

552+
// Builds the record name of a sparse block-scaled MMA intrinsic:
//   "int_nvvm_mma_sp_ordered_metadata_block_scale" "_" <A.geom> "_row_col"
//   "_" <Kind> <ScaleVecSize> <signature> "_" <SType>
// Every "." in ScaleVecSize is rewritten to "_", and an empty ScaleVecSize
// adds nothing to the name. Both the "ordered_metadata" variant and the
// "row"/"col" layout pair are fixed in the name rather than parameterized.
class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
553+
WMMA_REGS A, WMMA_REGS B,
554+
WMMA_REGS C, WMMA_REGS D> {
555+
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
556+
string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
557+
# "_" # A.geom
558+
# "_row_col"
559+
# "_" # Kind
560+
# !subst(".", "_", ScaleVecSize)
561+
# signature
562+
# "_" # SType;
563+
}
564+
536565
// Helper class that takes an intrinsic name and constructs a record name.
537566
// Additionally, sets `intr_name` to be non-empty if the default name assigned
538567
// to this intrinsic will not match the name given.
@@ -683,6 +712,18 @@ class NVVM_MMA_OPS {
683712
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
684713
int_mma_ops, subint_mma_ops, bit_mma_ops);
685714

715+
// Fragment combinations for dense block-scaled MMA.
// The m16n8k64 list is shared by the mxf4 and mxf4nvf4 kinds; which
// kind/scale_vec_size/stype combinations are actually defined for each
// geometry is decided by NVVM_MMA_BLOCK_SCALE_SUPPORTED.
list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
716+
["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
717+
>.ret;
718+
719+
// Fragment combinations paired with the mxf8f6f4 kind (geometry m16n8k32).
list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
720+
["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
721+
["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
722+
>.ret;
723+
724+
// Every dense block-scaled combination; this is the list the intrinsic
// definitions iterate over.
list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
725+
mxf4_mma_ops, mxf8f6f4_mma_ops);
726+
686727
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
687728
["m16n8k16", "m16n8k32"],
688729
["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +748,18 @@ class NVVM_MMA_OPS {
707748
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
708749
subint_mma_sp_ops, int_mma_sp_ops);
709750

751+
// Fragment combinations for sparse block-scaled MMA.
// This first list combines the available geoms and types for both the mxf4
// and the mxf4nvf4 kinds (hence "mxf4xx"); the per-kind filtering is done by
// NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED.
list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
753+
["m16n8k128"],
754+
["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
755+
// Fragment combinations paired with the mxf8f6f4 kind (geometry m16n8k64).
list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
756+
["m16n8k64"],
757+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
758+
["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
759+
["f32"], [], true>.ret;
760+
// Every sparse block-scaled combination; this is the list the intrinsic
// definitions iterate over.
list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
761+
mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
762+
710763
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
711764
["m16n16k16", "m32n8k16", "m8n32k16"],
712765
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +953,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
900953
);
901954
}
902955

956+
// Returns true if this combination of geometry/kind/scale_vec_size/stype
// for dense block-scaled MMA ops is supported; false otherwise.
// E.g.
// if NVVM_MMA_BLOCK_SCALE_SUPPORTED<...>.ret then
//   def : FOO<>; // The record will only be defined for supported ops.
//
// Accepted combinations (geom, kind, scale_vec_size, stype):
//   m16n8k64, mxf4,     "" or ".scale_2x", ue8m0
//   m16n8k64, mxf4nvf4, ".scale_2x",       ue8m0
//   m16n8k64, mxf4nvf4, ".scale_4x",       ue4m3
//   m16n8k32, mxf8f6f4, "" or ".scale_1x", ue8m0
// Only the geometry of frags[0] is inspected; element types are not checked.
class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
957+
string geom = frags[0].geom;
958+
959+
bit ret = !cond(
960+
!and(!eq(geom, "m16n8k64"),
961+
!eq(kind, "mxf4"),
962+
!or(!eq(scale_vec_size, ""),
963+
!eq(scale_vec_size, ".scale_2x")),
964+
!eq(stype, "ue8m0")) : true,
965+
!and(!eq(geom, "m16n8k64"),
966+
!eq(kind, "mxf4nvf4"),
967+
!eq(scale_vec_size, ".scale_2x"),
968+
!eq(stype, "ue8m0")) : true,
969+
!and(!eq(geom, "m16n8k64"),
970+
!eq(kind, "mxf4nvf4"),
971+
!eq(scale_vec_size, ".scale_4x"),
972+
!eq(stype, "ue4m3")) : true,
973+
!and(!eq(geom, "m16n8k32"),
974+
!eq(kind, "mxf8f6f4"),
975+
!or(!eq(scale_vec_size, ""),
976+
!eq(scale_vec_size, ".scale_1x")),
977+
!eq(stype, "ue8m0")) : true,
// All other combinations are not supported.
978+
true: false
979+
);
980+
}
981+
903982
// Returns true if the fragment is supported for ldmatrix ops;
904983
// false otherwise.
905984
// E.g.
@@ -998,6 +1077,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
9981077
}
9991078

10001079

1080+
// Returns true if this combination of geometry/kind/scale_vec_size/stype
1081+
// for block-scaled MMA.SP ops is supported;
1082+
// false otherwise.
1083+
// E.g.
1084+
// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
1085+
// def : FOO<>; // The record will only be defined for supported ops.
1086+
//
1087+
class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
1088+
string stype, string scale_vec_size> {
1089+
// Element types of the four fragments. They are recorded here but `ret`
// below does not consult them; only `geom` takes part in the decision.
1090+
string a_type = frags[0].ptx_elt_type;
1091+
string b_type = frags[1].ptx_elt_type;
1092+
string c_type = frags[2].ptx_elt_type;
1093+
string d_type = frags[3].ptx_elt_type;
1094+
string geom = frags[0].geom;
1095+
1096+
bit ret = !cond(
1097+
!and(!eq(geom, "m16n8k128"),
1098+
!eq(kind, "mxf4"),
1099+
!eq(stype, "ue8m0"),
1100+
!or(!eq(scale_vec_size, ""),
1101+
!eq(scale_vec_size, ".scale_2x"))): true,
1102+
1103+
!and(!eq(geom, "m16n8k128"),
1104+
!eq(kind, "mxf4nvf4"),
1105+
!eq(stype, "ue8m0"),
1106+
!eq(scale_vec_size, ".scale_2x")): true,
1107+
1108+
!and(!eq(geom, "m16n8k128"),
1109+
!eq(kind, "mxf4nvf4"),
1110+
!eq(stype, "ue4m3"),
1111+
!eq(scale_vec_size, ".scale_4x")): true,
1112+
1113+
!and(!eq(geom, "m16n8k64"),
1114+
!eq(kind, "mxf8f6f4"),
1115+
!eq(stype, "ue8m0"),
1116+
!or(!eq(scale_vec_size, ""),
1117+
!eq(scale_vec_size, ".scale_1x"))): true,
1118+
1119+
// All other combinations are NOT OK.
1120+
true: false
1121+
);
1122+
}
1123+
1124+
10011125
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
10021126
string Suffix = !if(sync, "sync_", "")
10031127
# mode # "_"
@@ -2415,6 +2539,31 @@ foreach layout_a = ["row", "col"] in {
24152539
} // layout_b
24162540
} // layout_a
24172541

2542+
// MMA BLOCK SCALE
// Intrinsic shape for dense block-scaled MMA: the results are D.regs and the
// operands are A.regs, B.regs and C.regs followed by six scalar operands
// carrying the scale data and its byte/thread selectors for A and then B.
// The intrinsic is marked IntrNoMem and IntrNoCallback.
class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2543+
: Intrinsic<D.regs,
2544+
!listconcat(A.regs, B.regs, C.regs,
2545+
[
2546+
llvm_i32_ty, // scale-a-data
2547+
llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2548+
llvm_i32_ty, // scale-b-data
2549+
llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2550+
]),
2551+
[IntrNoMem, IntrNoCallback]>;
2552+
2553+
// Walk every (kind, scale_vec_size, stype, op) tuple and define an intrinsic
// record only for the tuples NVVM_MMA_BLOCK_SCALE_SUPPORTED accepts. The
// record name comes from MMA_BLOCK_SCALE_NAME and the signature from
// NVVM_MMA_BLOCK_SCALE. An empty scale_vec_size stands for the form without
// an explicit scale-vector-size suffix.
foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2554+
foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2555+
foreach stype = ["ue8m0", "ue4m3"] in {
2556+
foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
2557+
if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2558+
def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2559+
op[0], op[1], op[2], op[3]>.record_name
2560+
: NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2561+
}
2562+
} // op
2563+
} // stype
2564+
} // scale_vec_size
2565+
} // kind
2566+
24182567
// MMA.SP
24192568
class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
24202569
: Intrinsic<D.regs,
@@ -2462,6 +2611,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
24622611
} // kind
24632612
} // metadata
24642613

2614+
// MMA.SP BLOCK SCALE
// Intrinsic shape for sparse block-scaled MMA: the results are D.regs and the
// operands are A.regs, B.regs and C.regs followed by the metadata, the
// sparsity selector, and the scale data with its byte/thread selectors for A
// and then B.
2615+
class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
2616+
: Intrinsic<D.regs,
2617+
!listconcat(A.regs, B.regs, C.regs,
2618+
[
2619+
llvm_i32_ty, // metadata
2620+
llvm_i32_ty, // sparsity selector
2621+
llvm_i32_ty, // scale-a-data
2622+
llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
2623+
llvm_i32_ty, // scale-b-data
2624+
llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
2625+
])> {
2626+
// Number of operands that precede the sparsity selector (all of A, B and C
// plus the single metadata operand); used below as the selector's ArgIndex.
int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
2627+
2628+
// The range [0;num_threads) is for the sparsity selector that indicates the threads
2629+
// which contribute metadata.
2630+
// According to PTX ISA 9.0, the sparsity selector is always 0
2631+
// for sparse MMA block scale instructions, so num_threads is 1 and 0 is the
// only value inside the range.
2632+
int num_threads = 1;
2633+
// The sparsity selector is additionally required to be an immediate (ImmArg).
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
2634+
Range<ArgIndex<pos>, 0, num_threads>];
2635+
}
2636+
2637+
// According to PTX ISA 9.0, sparse block-scaled MMA only exists for
2638+
// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
// so none of those is iterated here; they are fixed in the record name built
// by MMA_SP_BLOCK_SCALE_NAME. A record is defined only for the
// (kind, scale_vec_size, stype, op) tuples that
// NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED accepts.
2639+
foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
2640+
foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
2641+
foreach stype = ["ue8m0", "ue4m3"] in {
2642+
foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
2643+
if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
2644+
def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
2645+
op[0], op[1], op[2], op[3]>.record_name
2646+
: NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
2647+
}
2648+
} // op
2649+
} // stype
2650+
} // scale_vec_size
2651+
} // kind
2652+
24652653
// LDMATRIX
24662654
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
24672655
: Intrinsic<Frag.regs, [llvm_anyptr_ty],

0 commit comments

Comments
 (0)