Skip to content

Commit 7be3c3a

Browse files
[NVVM][NVPTX] Add support for tcgen05.mma (llvm#151949)
This commit adds support for tcgen05.mma instructions in NVPTX with tests under CodeGen/NVPTX/tcgen05-mma*. The tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments to model cta_group, mma kind, collector usage, etc. The rationale for the design is documented in the NVPTXUsage.rst file. For more details, please refer to the [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/?a#tcgen05-mma-instructions)
1 parent e7515ee commit 7be3c3a

17 files changed

+5065
-4
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 458 additions & 0 deletions
Large diffs are not rendered by default.

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 209 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,78 @@ class NVVM_TCGEN05_LDST_ACCESS_SIZE<string Shape, int Num> {
947947
true : llvm_void_ty);
948948
}
949949

950+
// Shared argument and attribute definitions for every tcgen05.mma intrinsic
// variant. Space selects where the A operand lives: "tensor" makes it a
// tensor-memory pointer, anything else an i64 (shared-memory descriptor).
// Sp == 1 appends the sparsity-metadata pointer used by the .sp variants.
class NVVM_TCGEN05_MMA_BASE<string Space, bit Sp> {
  // A operand type: tensor-memory pointer vs. 64-bit matrix descriptor.
  LLVMType a_operand_type = !if(!eq(Space, "tensor"),
llvm_tmem_ptr_ty, llvm_i64_ty);
  // Arguments common to all variants; subclasses append their extras.
  list<LLVMType> common_args = !listconcat(
[llvm_tmem_ptr_ty, // d
a_operand_type, // a
llvm_i64_ty, // b
llvm_i32_ty, // idesc
llvm_i1_ty], // enable_input_d
!if(!eq(Sp, 1), [llvm_tmem_ptr_ty], [])); // spmetadata
  // d (arg 0) is write-only; when A is in tensor memory, a (arg 1) is
  // read-only through its pointer argument.
  list<IntrinsicProperty> common_intr_props = !listconcat(
[IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
!if(!eq(Space, "tensor"), [ReadOnly<ArgIndex<1>>], [])
);
}
965+
966+
// Builds the intrinsic name and TableGen record name for the plain
// tcgen05.mma variants, e.g. "llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift".
// record is the same name with "llvm." -> "int_" and dots -> underscores.
class NVVM_TCGEN05_MMA<bit Sp, string Space,
bit AShift, bit ScaleInputD>:
NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma"
# !if(!eq(Sp, 1), ".sp", "")
# "." # Space
# !if(!eq(ScaleInputD, 1), ".scale_d", "")
# !if(!eq(AShift, 1), ".ashift", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
976+
977+
// Builds the intrinsic/record names for the block-scaled variants, e.g.
// "llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale.block32". ScaleVecSize is
// either "" (the kind's default) or a ".block16"/".block32" suffix.
class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
string Kind, string ScaleVecSize>:
NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma"
# !if(!eq(Sp, 1), ".sp", "")
# "." # Space
# "." # Kind
# ".block_scale" # ScaleVecSize;
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
987+
988+
// Builds the intrinsic/record names for the weight-stationary variants, e.g.
// "llvm.nvvm.tcgen05.mma.ws.sp.tensor.zero_col_mask".
class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma.ws"
# !if(!eq(Sp, 1), ".sp", "")
# "." # Space
# !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
996+
997+
// Builds the intrinsic/record names for the disable_output_lane variants.
// Unlike the other builders, the cta_group (1 or 2) is part of the name
// (".disable_output_lane.cg<N>") rather than a flag operand, because it
// determines the type of the lane-mask vector argument.
class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
int CtaGroup, bit AShift,
bit ScaleInputD>:
NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma"
# !if(!eq(Sp, 1), ".sp", "")
# "." # Space
# !if(!eq(ScaleInputD, 1), ".scale_d", "")
# ".disable_output_lane.cg" # CtaGroup
# !if(!eq(AShift, 1), ".ashift", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
1009+
1010+
// Predicate: true for exactly the (Kind, ScaleVecSize) pairs for which a
// block_scale intrinsic exists. "" means the kind's default scale-vector
// size. All other combinations are filtered out at definition time.
class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
  bit ret = !cond(
!and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, "")) : true,
!and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, "")) : true,
!and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block16")) : true,
!and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, ".block32")) : true,
!and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block32")) : true,
!and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ".block32")) : true,
true: false
);
}
1021+
9501022
class TexVector<string name, list<LLVMType> types> {
9511023
string Name = name;
9521024
list<LLVMType> Types = types;
@@ -2268,13 +2340,15 @@ def int_nvvm_exit : NVVMBuiltin,
22682340
// Like DefaultAttrsIntrinsic, but appends `flags` to the parameter list and
// marks each flag parameter as an immediate (ImmArg) at its post-concat
// position. The optional `name` overrides the intrinsic name derived from
// the record name (needed when the record name cannot encode it).
class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
list<LLVMType> param_types,
list<LLVMType> flags,
list<IntrinsicProperty> intr_properties,
string name = "">
: DefaultAttrsIntrinsic<
ret_types,
!listconcat(param_types, flags),
!listconcat(intr_properties,
!foreach(i, !range(flags),
ImmArg<ArgIndex<!add(i, !size(param_types))>>)),
name>;
22782352

22792353
// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
22802354
foreach dim = 1...5 in {
@@ -2663,4 +2737,136 @@ foreach dim = ["x", "y", "z"] in
26632737
: PureIntrinsic<[llvm_i32_ty], [llvm_i128_ty], [],
26642738
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
26652739

2666-
} // let TargetPrefix = "nvvm"
2740+
//
// tcgen05.mma intrinsics
//

foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach scale_d = [0, 1] in {
      // .ashift is only available when the A operand is in tensor memory.
      foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
        defvar variant = NVVM_TCGEN05_MMA<sp, space, ashift, scale_d>;

        // Value arguments: the common set, plus the immediate scale
        // operand for the .scale_d variants.
        defvar arg_types = !listconcat(
            variant.common_args,
            !if(!eq(scale_d, 1), [llvm_i64_ty], [])); // scale_d_imm

        // Trailing immediate flags: kind, cta_group, collector_usage_a.
        defvar flag_types = [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty];

        defvar num_args = !size(arg_types);
        // scale_d_imm, when present, is the last value argument.
        defvar scale_d_idx = ArgIndex<!sub(num_args, 1)>;
        defvar scale_d_props = [ImmArg<scale_d_idx>,
                                Range<scale_d_idx, 0, 16>];

        defvar props = !listconcat(
            variant.common_intr_props,
            !if(!eq(scale_d, 1), scale_d_props, []),
            // Flag ranges: .scale_d narrows the valid kinds, .ashift
            // narrows the valid collector usages.
            [Range<ArgIndex<num_args>, 0, !if(!eq(scale_d, 1), 2, 4)>, // kind
             Range<ArgIndex<!add(num_args, 1)>, 1, 3>,                // cta_group
             Range<ArgIndex<!add(num_args, 2)>, 0,
                   !if(!eq(ashift, 1), 2, 4)>]);                      // collector_usage

        def variant.record:
            DefaultAttrsIntrinsicFlags<[], arg_types, flag_types, props,
                                       variant.intr>;
      } // ashift
    } // scale_d
  } // space
} // sp
2776+
2777+
//
// tcgen05.mma disable_output_lane intrinsics
//
// These variants carry an extra vector operand (v4i32 for cta_group::1,
// v8i32 for cta_group::2) naming the output lanes to leave untouched.
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach cta_group = [1, 2] in {
      foreach scale_d = [0, 1] in {
        // .ashift is only available when A is in tensor memory.
        foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
          defvar mma = NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<
sp, space, cta_group, ashift, scale_d>;
          // Mask width follows the cta_group baked into the name.
          defvar disable_output_lane_type =
!if(!eq(cta_group, 1), llvm_v4i32_ty, llvm_v8i32_ty);
          defvar args = !listconcat(
mma.common_args,
!if(!eq(scale_d, 1), [llvm_i64_ty], []),
[disable_output_lane_type]
);
          defvar flags = [llvm_i32_ty, // kind_flag
llvm_i32_ty]; // collector_usage_a_flag
          defvar nargs = !size(args);
          // The mask vector is always last, so scale_d_imm (when present)
          // sits one position earlier, at nargs - 2.
          defvar scale_d_imm = ArgIndex<!sub(nargs, 2)>;
          defvar scale_d_imm_range = [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>];
          // Flag ranges: .scale_d narrows the valid kinds, .ashift narrows
          // the valid collector usages.
          defvar intrinsic_properties = !listconcat(
mma.common_intr_props,
!if(!eq(scale_d, 1), scale_d_imm_range, []),
[Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
);

          def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
mma.intr>;
        } // ashift
      } // scale_d
    } // cta_group
  } // space
} // sp
2813+
2814+
//
// tcgen05.mma block_scale intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
      foreach scale_vec_size = ["", ".block16", ".block32"] in {
        defvar variant =
            NVVM_TCGEN05_MMA_BLOCKSCALE<sp, space, kind, scale_vec_size>;

        // Block-scale variants take tensor-memory pointers to the A and B
        // scale factors after the common argument list.
        defvar arg_types = !listconcat(variant.common_args,
                                       [llvm_tmem_ptr_ty,   // scale_a
                                        llvm_tmem_ptr_ty]); // scale_b
        // Trailing immediate flags: cta_group, collector_usage_a.
        defvar flag_types = [llvm_i32_ty, llvm_i32_ty];

        defvar num_args = !size(arg_types);
        defvar cta_group_idx = ArgIndex<num_args>;
        defvar collector_usage_idx = ArgIndex<!add(num_args, 1)>;

        // Only the (kind, scale_vec_size) combinations the whitelist allows
        // produce an intrinsic.
        if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
          def variant.record:
              DefaultAttrsIntrinsicFlags<
                  [], arg_types, flag_types,
                  !listconcat(variant.common_intr_props,
                              [Range<cta_group_idx, 1, 3>,
                               Range<collector_usage_idx, 0, 4>]),
                  variant.intr>;
        }
      } // scale_vec_size
    } // kind
  } // space
} // sp
2842+
2843+
//
// tcgen05.mma ws intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach zero_col_mask = [0, 1] in {
      defvar variant = NVVM_TCGEN05_MMA_WS<sp, space, zero_col_mask>;

      // The .zero_col_mask variants take one extra i64 mask operand.
      defvar arg_types = !listconcat(
          variant.common_args,
          !if(!eq(zero_col_mask, 1), [llvm_i64_ty], []));

      // Trailing immediate flags: kind, collector_buffer_b,
      // collector_usage_b_op.
      defvar flag_types = [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty];

      defvar num_args = !size(arg_types);
      defvar props = !listconcat(
          variant.common_intr_props,
          [Range<ArgIndex<num_args>, 0, 4>,
           Range<ArgIndex<!add(num_args, 1)>, 0, 4>,
           Range<ArgIndex<!add(num_args, 2)>, 0, 4>]);

      def variant.record:
          DefaultAttrsIntrinsicFlags<[], arg_types, flag_types, props,
                                     variant.intr>;
    } // zero_col_mask
  } // space
} // sp
2871+
2872+
} // let TargetPrefix = "nvvm"

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ enum class CTAGroupKind : uint8_t {
4747
CG_2 = 2, // cta_group::2 modifier
4848
};
4949

50+
// Encodings for the tcgen05.mma intrinsics' `kind` immediate flag
// (values must stay in sync with the Range<> bounds in IntrinsicsNVVM.td).
enum class Tcgen05MMAKind : uint8_t { F16 = 0, TF32 = 1, F8F6F4 = 2, I8 = 3 };
51+
52+
// Encodings for the tcgen05.mma intrinsics' collector-usage immediate flag;
// names mirror the PTX collector ::discard/::lastuse/::fill/::use modifiers
// (values must stay in sync with the Range<> bounds in IntrinsicsNVVM.td).
enum class Tcgen05CollectorUsageOp : uint8_t {
  DISCARD = 0,
  LASTUSE = 1,
  FILL = 2,
  USE = 3,
};
58+
5059
inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
5160
switch (IntrinsicID) {
5261
case Intrinsic::nvvm_f2i_rm_ftz:

0 commit comments

Comments
 (0)