@@ -947,6 +947,78 @@ class NVVM_TCGEN05_LDST_ACCESS_SIZE<string Shape, int Num> {
947947 true : llvm_void_ty);
948948}
949949
// Shared building blocks for the tcgen05.mma intrinsic signatures, selected by
// the Space string and the Sp bit.
//   a_operand_type    - llvm_tmem_ptr_ty when Space is "tensor", otherwise
//                       llvm_i64_ty.
//   common_args       - leading operand types in order: d, a, b, idesc,
//                       enable_input_d; one extra llvm_tmem_ptr_ty (spmetadata)
//                       is appended when Sp is 1.
//   common_intr_props - IntrArgMemOnly and WriteOnly on argument 0 (d); ReadOnly
//                       on argument 1 (a) is added only for the "tensor" space,
//                       which is the case where a has type llvm_tmem_ptr_ty.
950+ class NVVM_TCGEN05_MMA_BASE<string Space, bit Sp> {
951+ LLVMType a_operand_type = !if(!eq(Space, "tensor"),
952+ llvm_tmem_ptr_ty, llvm_i64_ty);
953+ list<LLVMType> common_args = !listconcat(
954+ [llvm_tmem_ptr_ty, // d
955+ a_operand_type, // a
956+ llvm_i64_ty, // b
957+ llvm_i32_ty, // idesc
958+ llvm_i1_ty], // enable_input_d
959+ !if(!eq(Sp, 1), [llvm_tmem_ptr_ty], [])); // spmetadata
960+ list<IntrinsicProperty> common_intr_props = !listconcat(
961+ [IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
962+ !if(!eq(Space, "tensor"), [ReadOnly<ArgIndex<1>>], [])
963+ );
964+ }
965+
// Name builder for the plain tcgen05.mma intrinsics; inherits the common
// argument/property lists from NVVM_TCGEN05_MMA_BASE<Space, Sp>.
//   intr   - "llvm.nvvm.tcgen05.mma", then ".sp" if Sp is 1, then "." # Space,
//            then ".scale_d" if ScaleInputD is 1, then ".ashift" if AShift is 1.
//            Note the suffix order: ".scale_d" comes before ".ashift".
//   record - intr with "llvm." replaced by "int_" and every remaining "."
//            replaced by "_", e.g.
//            "llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift" becomes
//            "int_nvvm_tcgen05_mma_sp_tensor_scale_d_ashift".
966+ class NVVM_TCGEN05_MMA<bit Sp, string Space,
967+ bit AShift, bit ScaleInputD>:
968+ NVVM_TCGEN05_MMA_BASE<Space, Sp> {
969+ string intr = "llvm.nvvm.tcgen05.mma"
970+ # !if(!eq(Sp, 1), ".sp", "")
971+ # "." # Space
972+ # !if(!eq(ScaleInputD, 1), ".scale_d", "")
973+ # !if(!eq(AShift, 1), ".ashift", "");
974+ string record = !subst(".", "_", !subst("llvm.", "int_", intr));
975+ }
976+
// Name builder for the tcgen05.mma block_scale intrinsics; inherits the common
// argument/property lists from NVVM_TCGEN05_MMA_BASE<Space, Sp>.
//   intr   - "llvm.nvvm.tcgen05.mma", then ".sp" if Sp is 1, then "." # Space,
//            "." # Kind, and ".block_scale" # ScaleVecSize.
//            ScaleVecSize is pasted verbatim with no separator, so a non-empty
//            value must carry its own leading "." (the callers in this file
//            pass "", ".block16" or ".block32").
//   record - intr with "llvm." replaced by "int_" and every remaining "."
//            replaced by "_".
977+ class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
978+ string Kind, string ScaleVecSize>:
979+ NVVM_TCGEN05_MMA_BASE<Space, Sp> {
980+ string intr = "llvm.nvvm.tcgen05.mma"
981+ # !if(!eq(Sp, 1), ".sp", "")
982+ # "." # Space
983+ # "." # Kind
984+ # ".block_scale" # ScaleVecSize;
985+ string record = !subst(".", "_", !subst("llvm.", "int_", intr));
986+ }
987+
// Name builder for the tcgen05.mma.ws intrinsics; inherits the common
// argument/property lists from NVVM_TCGEN05_MMA_BASE<Space, Sp>.
//   intr   - "llvm.nvvm.tcgen05.mma.ws", then ".sp" if Sp is 1, then
//            "." # Space, then ".zero_col_mask" if ZeroColMask is 1.
//   record - intr with "llvm." replaced by "int_" and every remaining "."
//            replaced by "_".
988+ class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
989+ NVVM_TCGEN05_MMA_BASE<Space, Sp> {
990+ string intr = "llvm.nvvm.tcgen05.mma.ws"
991+ # !if(!eq(Sp, 1), ".sp", "")
992+ # "." # Space
993+ # !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
994+ string record = !subst(".", "_", !subst("llvm.", "int_", intr));
995+ }
996+
// Name builder for the tcgen05.mma disable_output_lane intrinsics; inherits the
// common argument/property lists from NVVM_TCGEN05_MMA_BASE<Space, Sp>.
//   intr   - "llvm.nvvm.tcgen05.mma", then ".sp" if Sp is 1, then "." # Space,
//            ".scale_d" if ScaleInputD is 1, ".disable_output_lane.cg" with the
//            integer CtaGroup pasted directly after it (e.g. ".cg1"), and
//            finally ".ashift" if AShift is 1.
//   record - intr with "llvm." replaced by "int_" and every remaining "."
//            replaced by "_".
997+ class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
998+ int CtaGroup, bit AShift,
999+ bit ScaleInputD>:
1000+ NVVM_TCGEN05_MMA_BASE<Space, Sp> {
1001+ string intr = "llvm.nvvm.tcgen05.mma"
1002+ # !if(!eq(Sp, 1), ".sp", "")
1003+ # "." # Space
1004+ # !if(!eq(ScaleInputD, 1), ".scale_d", "")
1005+ # ".disable_output_lane.cg" # CtaGroup
1006+ # !if(!eq(AShift, 1), ".ashift", "");
1007+ string record = !subst(".", "_", !subst("llvm.", "int_", intr));
1008+ }
1009+
// Predicate: `ret` is true only for the (Kind, ScaleVecSize) pairs listed
// below, and false for every other combination (the final `true: false` arm is
// the default case of !cond).
// Accepted pairs:
//   mxf8f6f4 : "", ".block32"
//   mxf4     : "", ".block32"
//   mxf4nvf4 : ".block16", ".block32"
// ScaleVecSize is compared as an exact string, including its leading ".".
1010+ class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
1011+ bit ret = !cond(
1012+ !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, "")) : true,
1013+ !and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, "")) : true,
1014+ !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block16")) : true,
1015+ !and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, ".block32")) : true,
1016+ !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block32")) : true,
1017+ !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ".block32")) : true,
1018+ true: false
1019+ );
1020+ }
1021+
9501022class TexVector<string name, list<LLVMType> types> {
9511023 string Name = name;
9521024 list<LLVMType> Types = types;
@@ -2268,13 +2340,15 @@ def int_nvvm_exit : NVVMBuiltin,
// Wrapper around DefaultAttrsIntrinsic for intrinsics whose trailing operands
// are flags: the parameter list passed on is param_types followed by flags, and
// one ImmArg property is appended to intr_properties for every flag, at
// argument index i + !size(param_types) for flag position i.
// This hunk shows both the removed ("-") and added ("+") lines: the added lines
// introduce an optional trailing `name` parameter (default "") and forward it
// as the last template argument of DefaultAttrsIntrinsic.
22682340class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
22692341 list<LLVMType> param_types,
22702342 list<LLVMType> flags,
2271- list<IntrinsicProperty> intr_properties>
2343+ list<IntrinsicProperty> intr_properties,
2344+ string name = "">
22722345 : DefaultAttrsIntrinsic<
22732346 ret_types,
22742347 !listconcat(param_types, flags),
22752348 !listconcat(intr_properties,
22762349 !foreach(i, !range(flags),
2277- ImmArg<ArgIndex<!add(i, !size(param_types))>>))>;
2350+ ImmArg<ArgIndex<!add(i, !size(param_types))>>)),
2351+ name>;
22782352
22792353// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
22802354foreach dim = 1...5 in {
@@ -2663,4 +2737,136 @@ foreach dim = ["x", "y", "z"] in
26632737 : PureIntrinsic<[llvm_i32_ty], [llvm_i128_ty], [],
26642738 "llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
26652739
2666- } // let TargetPrefix = "nvvm"
2740+ //
2741+ // tcgen05.mma intrinsics
2742+ //
2743+
// Defines one intrinsic record per (sp, space, scale_d, ashift) combination.
// ashift = 1 is generated only for the "tensor" space; "shared" gets ashift = 0
// only. Record and intrinsic names come from NVVM_TCGEN05_MMA.
//
// Argument layout: `args` (common_args, plus a trailing llvm_i64_ty
// scale_d_imm when scale_d is 1) followed by the three i32 `flags`, which
// DefaultAttrsIntrinsicFlags appends and marks ImmArg. The flags therefore sit
// at indices nargs, nargs + 1 and nargs + 2.
// NOTE(review): the Range<idx, lo, hi> bounds are documented here as written;
// LLVM's Range property is presumably half-open [lo, hi) - confirm.
2744+ foreach sp = [0, 1] in {
2745+ foreach space = ["tensor", "shared"] in {
2746+ foreach scale_d = [0, 1] in {
2747+ foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
2748+ defvar mma = NVVM_TCGEN05_MMA<sp, space, ashift, scale_d>;
2749+ defvar args = !listconcat(
2750+ mma.common_args,
2751+ !if(!eq(scale_d, 1), [llvm_i64_ty], []) // scale_d_imm
2752+ );
2753+ defvar flags = [llvm_i32_ty, // kind
2754+ llvm_i32_ty, // cta_group
2755+ llvm_i32_ty]; // collector_usage_a
2756+ defvar nargs = !size(args);
// Index of the last entry of `args`; it is scale_d_imm only when scale_d is 1,
// and scale_d_imm_range is only spliced into the properties in that case.
2757+ defvar scale_d_imm = ArgIndex<!sub(nargs, 1)>;
2758+ defvar scale_d_imm_range = [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>];
// The upper bound for `kind` is 2 when scale_d is 1 and 4 otherwise; the upper
// bound for collector_usage is 2 when ashift is 1 and 4 otherwise.
// NOTE(review): presumably only a subset of kind / collector_usage values is
// legal with scale_d / ashift - confirm against the PTX ISA.
2759+ defvar intrinsic_properties = !listconcat(
2760+ mma.common_intr_props,
2761+ !if(!eq(scale_d, 1), scale_d_imm_range, []),
2762+ [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>, // kind
2763+ Range<ArgIndex<!add(nargs, 1)>, 1, 3>, // cta_group
2764+ Range<ArgIndex<!add(nargs, 2)>, 0,
2765+ !if(!eq(ashift, 1), 2, 4)> // collector_usage
2766+ ]
2767+ );
2768+
2769+ def mma.record:
2770+ DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2771+ mma.intr>;
2772+ }
2773+ }
2774+ }
2775+ }
2776+
2777+ //
2778+ // tcgen05.mma disable_output_lane intrinsics
2779+ //
// Defines one intrinsic record per (sp, space, cta_group, scale_d, ashift)
// combination. ashift = 1 is generated only for the "tensor" space. Record and
// intrinsic names come from NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE, so cta_group
// is encoded in the name (".cg1" / ".cg2") rather than passed as a flag.
//
// Argument layout: common_args, then an llvm_i64_ty scale_d_imm when scale_d is
// 1, then the disable_output_lane vector (llvm_v4i32_ty for cta_group 1,
// llvm_v8i32_ty otherwise), followed by the two i32 `flags` at indices nargs
// and nargs + 1.
// NOTE(review): the Range<idx, lo, hi> bounds are documented here as written;
// LLVM's Range property is presumably half-open [lo, hi) - confirm.
2780+ foreach sp = [0, 1] in {
2781+ foreach space = ["tensor", "shared"] in {
2782+ foreach cta_group = [1, 2] in {
2783+ foreach scale_d = [0, 1] in {
2784+ foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
2785+ defvar mma = NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<
2786+ sp, space, cta_group, ashift, scale_d>;
2787+ defvar disable_output_lane_type =
2788+ !if(!eq(cta_group, 1), llvm_v4i32_ty, llvm_v8i32_ty);
2789+ defvar args = !listconcat(
2790+ mma.common_args,
2791+ !if(!eq(scale_d, 1), [llvm_i64_ty], []),
2792+ [disable_output_lane_type]
2793+ );
2794+ defvar flags = [llvm_i32_ty, // kind_flag
2795+ llvm_i32_ty]; // collector_usage_a_flag
2796+ defvar nargs = !size(args);
// nargs - 2 because the disable_output_lane vector is the last entry of `args`,
// so scale_d_imm (present only when scale_d is 1) is the one before it.
2797+ defvar scale_d_imm = ArgIndex<!sub(nargs, 2)>;
2798+ defvar scale_d_imm_range = [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>];
// First Range constrains the kind flag (upper bound 2 with scale_d, else 4);
// second constrains the collector_usage flag (upper bound 2 with ashift, else 4).
2799+ defvar intrinsic_properties = !listconcat(
2800+ mma.common_intr_props,
2801+ !if(!eq(scale_d, 1), scale_d_imm_range, []),
2802+ [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
2803+ Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
2804+ );
2805+
2806+ def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2807+ mma.intr>;
2808+ } // ashift
2809+ } // scale_d
2810+ } // cta_group
2811+ } // space
2812+ } // sp
2813+
2814+ //
2815+ // tcgen05.mma block_scale intrinsics
2816+ //
// Iterates every (sp, space, kind, scale_vec_size) combination but defines an
// intrinsic record only for the (kind, scale_vec_size) pairs accepted by
// NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED; the other pairs produce no record.
// Record and intrinsic names come from NVVM_TCGEN05_MMA_BLOCKSCALE.
//
// Argument layout: common_args, then two llvm_tmem_ptr_ty operands (scale_a,
// scale_b), followed by the two i32 `flags` (cta_group at index nargs,
// collector_usage_a at index nargs + 1).
// NOTE(review): the Range<idx, lo, hi> bounds are documented here as written;
// LLVM's Range property is presumably half-open [lo, hi) - confirm.
2817+ foreach sp = [0, 1] in {
2818+ foreach space = ["tensor", "shared"] in {
2819+ foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
2820+ foreach scale_vec_size = ["", ".block16", ".block32"] in {
2821+ defvar mma = NVVM_TCGEN05_MMA_BLOCKSCALE<sp, space, kind, scale_vec_size>;
2822+ defvar args = !listconcat(mma.common_args,
2823+ [llvm_tmem_ptr_ty, // scale_a
2824+ llvm_tmem_ptr_ty]); // scale_b
2825+ defvar flags = [llvm_i32_ty, // cta_group
2826+ llvm_i32_ty]; // collector_usage_a
2827+ defvar nargs = !size(args);
2828+ defvar cta_group = ArgIndex<nargs>;
2829+ defvar collector_usage = ArgIndex<!add(nargs, 1)>;
2830+
2831+ if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
2832+ def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags,
2833+ !listconcat(mma.common_intr_props,
2834+ [Range<cta_group, 1, 3>,
2835+ Range<collector_usage, 0, 4>]),
2836+ mma.intr>;
2837+ }
2838+ }
2839+ }
2840+ }
2841+ }
2842+
2843+ //
2844+ // tcgen05.mma ws intrinsics
2845+ //
// Defines one intrinsic record per (sp, space, zero_col_mask) combination.
// Record and intrinsic names come from NVVM_TCGEN05_MMA_WS.
//
// Argument layout: common_args, plus a trailing llvm_i64_ty when zero_col_mask
// is 1, followed by the three i32 `flags` (kind, collector_buffer_b,
// collector_usage_b_op) at indices nargs, nargs + 1 and nargs + 2. All three
// flags get the same Range<..., 0, 4> constraint.
// NOTE(review): the Range<idx, lo, hi> bounds are documented here as written;
// LLVM's Range property is presumably half-open [lo, hi) - confirm.
2846+ foreach sp = [0, 1] in {
2847+ foreach space = ["tensor", "shared"] in {
2848+ foreach zero_col_mask = [0, 1] in {
2849+ defvar mma = NVVM_TCGEN05_MMA_WS<sp, space, zero_col_mask>;
2850+ defvar args = !listconcat(
2851+ mma.common_args,
2852+ !if(!eq(zero_col_mask, 1), [llvm_i64_ty], [])
2853+ );
2854+ defvar flags = [llvm_i32_ty, // kind
2855+ llvm_i32_ty, // collector_buffer_b
2856+ llvm_i32_ty]; // collector_usage_b_op
2857+ defvar nargs = !size(args);
2858+ defvar intrinsic_properties = !listconcat(
2859+ mma.common_intr_props,
2860+ [Range<ArgIndex<nargs>, 0, 4>,
2861+ Range<ArgIndex<!add(nargs, 1)>, 0, 4>,
2862+ Range<ArgIndex<!add(nargs, 2)>, 0, 4>]
2863+ );
2864+
2865+ def mma.record:
2866+ DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
2867+ mma.intr>;
2868+ }
2869+ }
2870+ }
2871+
2872+ } // let TargetPrefix = "nvvm"
0 commit comments