Skip to content

Commit 4842b27

Browse files
[NVVM][NVPTX] Add support for tcgen05.mma
This commit adds support for tcgen05.mma instructions in NVPTX which tests under CodeGen/NVPTX/tcgen05-mma*. This tcgen05.mma instructions are modeled as intrinsics with multiple arguments to model cta_group, mma kind, collector usage etc. The rationale for the design is present documented in NVPTXUsage.rst file
1 parent f1eb869 commit 4842b27

14 files changed

+4402
-12
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 461 additions & 3 deletions
Large diffs are not rendered by default.

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 210 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,76 @@ class NVVM_TCGEN05_LDST_ACCESS_SIZE<string Shape, int Num> {
764764
true : llvm_void_ty);
765765
}
766766

767+
class NVVM_TCGEN05_MMA_BASE<string Space> {
768+
LLVMType a_operand_type = !if(!eq(Space, "tensor"),
769+
llvm_tmem_ptr_ty, llvm_i64_ty);
770+
list<LLVMType> common_args = [llvm_tmem_ptr_ty, // d
771+
a_operand_type, // a
772+
llvm_i64_ty, // b
773+
llvm_i32_ty, // idesc
774+
llvm_i1_ty]; // enable_input_d
775+
list<IntrinsicProperty> common_intr_props = !listconcat(
776+
[IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
777+
!if(!eq(Space, "tensor"), [ReadOnly<ArgIndex<1>>], [])
778+
);
779+
}
780+
781+
class NVVM_TCGEN05_MMA<bit Sp, string Space,
782+
bit AShift, bit ScaleInputD>:
783+
NVVM_TCGEN05_MMA_BASE<Space> {
784+
string intr = "llvm.nvvm.tcgen05.mma"
785+
# !if(!eq(Sp, 1), ".sp", "")
786+
# "." # Space
787+
# !if(!eq(ScaleInputD, 1), ".scale_d", "")
788+
# !if(!eq(AShift, 1), ".ashift", "");
789+
string record = !subst(".", "_", !subst("llvm.", "int_", intr));
790+
}
791+
792+
class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
793+
string Kind, string ScaleVecSize>:
794+
NVVM_TCGEN05_MMA_BASE<Space> {
795+
string intr = "llvm.nvvm.tcgen05.mma"
796+
# !if(!eq(Sp, 1), ".sp", "")
797+
# "." # Space
798+
# "." # Kind # ScaleVecSize
799+
# ".block_scale";
800+
string record = !subst(".", "_", !subst("llvm.", "int_", intr));
801+
}
802+
803+
class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
804+
NVVM_TCGEN05_MMA_BASE<Space> {
805+
string intr = "llvm.nvvm.tcgen05.mma.ws"
806+
# !if(!eq(Sp, 1), ".sp", "")
807+
# "." # Space
808+
# !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
809+
string record = !subst(".", "_", !subst("llvm.", "int_", intr));
810+
}
811+
812+
class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
813+
int CtaGroup, bit AShift,
814+
bit ScaleInputD>:
815+
NVVM_TCGEN05_MMA_BASE<Space> {
816+
string intr = "llvm.nvvm.tcgen05.mma"
817+
# !if(!eq(Sp, 1), ".sp", "")
818+
# "." # Space
819+
# !if(!eq(ScaleInputD, 1), ".scale_d", "")
820+
# ".disable_output_lane.cg" # CtaGroup
821+
# !if(!eq(AShift, 1), ".ashift", "");
822+
string record = !subst(".", "_", !subst("llvm.", "int_", intr));
823+
}
824+
825+
class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
826+
bit ret = !cond(
827+
!and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, "")) : true,
828+
!and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, "")) : true,
829+
!and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block16")) : true,
830+
!and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, ".block32")) : true,
831+
!and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block32")) : true,
832+
!and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ".block32")) : true,
833+
true: false
834+
);
835+
}
836+
767837
class TexVector<string name, list<LLVMType> types> {
768838
string Name = name;
769839
list<LLVMType> Types = types;
@@ -2070,13 +2140,15 @@ def int_nvvm_exit : NVVMBuiltin,
20702140
class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
20712141
list<LLVMType> param_types,
20722142
list<LLVMType> flags,
2073-
list<IntrinsicProperty> intr_properties>
2143+
list<IntrinsicProperty> intr_properties,
2144+
string name = "">
20742145
: DefaultAttrsIntrinsic<
20752146
ret_types,
20762147
!listconcat(param_types, flags),
20772148
!listconcat(intr_properties,
20782149
!foreach(i, !range(flags),
2079-
ImmArg<ArgIndex<!add(i, !size(param_types))>>))>;
2150+
ImmArg<ArgIndex<!add(i, !size(param_types))>>)),
2151+
name>;
20802152

20812153
// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
20822154
foreach dim = 1...5 in {
@@ -2464,4 +2536,139 @@ def int_nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_ # dim
24642536
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
24652537
}
24662538

2467-
} // let TargetPrefix = "nvvm"
2539+
//
2540+
// tcgen05.mma intrinsics
2541+
//
2542+
2543+
foreach sp = [0, 1] in {
2544+
foreach space = ["tensor", "shared"] in {
2545+
foreach scale_d = [0, 1] in {
2546+
foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
2547+
defvar mma = NVVM_TCGEN05_MMA<sp, space, ashift, scale_d>;
2548+
defvar args = !listconcat(
2549+
mma.common_args,
2550+
!if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
2551+
!if(!eq(scale_d, 1), [llvm_i64_ty], [])
2552+
);
2553+
defvar flags = [llvm_i32_ty, // kind
2554+
llvm_i32_ty, // cta_group
2555+
llvm_i32_ty]; // collector_usage_a
2556+
defvar nargs = !size(args);
2557+
defvar scale_d_imm = ArgIndex<!sub(nargs, 1)>;
2558+
defvar intrinsic_properties = !listconcat(
2559+
mma.common_intr_props,
2560+
!if(!eq(scale_d, 1),
2561+
[ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>], []),
2562+
[Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
2563+
Range<ArgIndex<!add(nargs, 1)>, 1, 3>,
2564+
Range<ArgIndex<!add(nargs, 2)>, 0, !if(!eq(ashift, 1), 2, 4)>]
2565+
);
2566+
2567+
def mma.record:
2568+
DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2569+
mma.intr>;
2570+
}
2571+
}
2572+
}
2573+
}
2574+
2575+
//
2576+
// tcgen05.mma disable_output_lane intrinsics
2577+
//
2578+
foreach sp = [0, 1] in {
2579+
foreach space = ["tensor", "shared"] in {
2580+
foreach cta_group = [1, 2] in {
2581+
foreach scale_d = [0, 1] in {
2582+
foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
2583+
defvar mma = NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<
2584+
sp, space, cta_group, ashift, scale_d>;
2585+
defvar disable_output_lane_type =
2586+
!if(!eq(cta_group, 1), llvm_v4i32_ty, llvm_v8i32_ty);
2587+
defvar args = !listconcat(
2588+
mma.common_args,
2589+
!if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
2590+
!if(!eq(scale_d, 1), [llvm_i64_ty], []),
2591+
[disable_output_lane_type]
2592+
);
2593+
defvar flags = [llvm_i32_ty, // kind_flag
2594+
llvm_i32_ty]; // collector_usage_a_flag
2595+
defvar nargs = !size(args);
2596+
defvar scale_d_flag = ArgIndex<!sub(nargs, 2)>;
2597+
defvar scale_d_imm_range = [ImmArg<scale_d_flag>, Range<scale_d_flag, 0, 16>];
2598+
defvar intrinsic_properties = !listconcat(
2599+
mma.common_intr_props,
2600+
!if(!eq(scale_d, 1), scale_d_imm_range, []),
2601+
[Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
2602+
Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
2603+
);
2604+
2605+
def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2606+
mma.intr>;
2607+
} // ashift
2608+
} // scale_d
2609+
} // cta_group
2610+
} // space
2611+
} // sp
2612+
2613+
//
2614+
// tcgen05.mma block_scale intrinsics
2615+
//
2616+
foreach sp = [0, 1] in {
2617+
foreach space = ["tensor", "shared"] in {
2618+
foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
2619+
foreach scale_vec_size = ["", ".block16", ".block32"] in {
2620+
defvar mma = NVVM_TCGEN05_MMA_BLOCKSCALE<sp, space, kind, scale_vec_size>;
2621+
defvar cta_group = ArgIndex<!if(!eq(sp, 1), 8, 7)>;
2622+
defvar collector_usage = ArgIndex<!if(!eq(sp, 1), 9, 8)>;
2623+
2624+
if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
2625+
def mma.record: DefaultAttrsIntrinsicFlags<[],
2626+
!listconcat(mma.common_args,
2627+
!if(!eq(sp, 1),
2628+
[llvm_tmem_ptr_ty], []), // spmetadata
2629+
[llvm_tmem_ptr_ty, // scale a
2630+
llvm_tmem_ptr_ty]), // scale b
2631+
// flags
2632+
[llvm_i32_ty, // cta_group
2633+
llvm_i32_ty], // collector_usage_a
2634+
!listconcat(mma.common_intr_props,
2635+
[Range<cta_group, 1, 3>,
2636+
Range<collector_usage, 0, 4>]),
2637+
mma.intr>;
2638+
}
2639+
}
2640+
}
2641+
}
2642+
}
2643+
2644+
//
2645+
// tcgen05.mma ws intrinsics
2646+
//
2647+
foreach sp = [0, 1] in {
2648+
foreach space = ["tensor", "shared"] in {
2649+
foreach zero_col_mask = [0, 1] in {
2650+
defvar mma = NVVM_TCGEN05_MMA_WS<sp, space, zero_col_mask>;
2651+
defvar args = !listconcat(
2652+
mma.common_args,
2653+
!if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
2654+
!if(!eq(zero_col_mask, 1), [llvm_i64_ty], [])
2655+
);
2656+
defvar flags = [llvm_i32_ty, // kind
2657+
llvm_i32_ty, // collector_buffer_b
2658+
llvm_i32_ty]; // collector_usage_b_op
2659+
defvar nargs = !size(args);
2660+
defvar intrinsic_properties = !listconcat(
2661+
mma.common_intr_props,
2662+
[Range<ArgIndex<nargs>, 0, 4>,
2663+
Range<ArgIndex<!add(nargs, 1)>, 0, 4>,
2664+
Range<ArgIndex<!add(nargs, 2)>, 0, 4>]
2665+
);
2666+
2667+
def mma.record:
2668+
DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2669+
mma.intr>;
2670+
}
2671+
}
2672+
}
2673+
2674+
} // let TargetPrefix = "nvvm"

llvm/include/llvm/IR/NVVMIntrinsicUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@ enum class CTAGroupKind : uint8_t {
4747
CG_2 = 2, // cta_group::2 modifier
4848
};
4949

50+
enum class Tcgen05MMAKind : uint8_t { F16 = 0, TF32 = 1, F8F6F4 = 2, I8 = 3 };
51+
52+
enum class Tcgen05CollectorUsageOp : uint8_t {
53+
DISCARD = 0,
54+
LASTUSE = 1,
55+
FILL = 2,
56+
USE = 3,
57+
};
58+
5059
inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
5160
switch (IntrinsicID) {
5261
case Intrinsic::nvvm_f2i_rm_ftz:

0 commit comments

Comments
 (0)