@@ -947,6 +947,78 @@ class NVVM_TCGEN05_LDST_ACCESS_SIZE<string Shape, int Num> {
947
947
true : llvm_void_ty);
948
948
}
949
949
950
// Base class holding the operand list and memory-access properties shared by
// every tcgen05.mma intrinsic variant.
// Space: "tensor" -> the A operand is a tensor-memory pointer; anything else
//        (e.g. "shared") -> A is a 64-bit descriptor value.
// Sp:    1 -> sparse variant; appends a tensor-memory sparsity-metadata
//        pointer to the common argument list.
class NVVM_TCGEN05_MMA_BASE<string Space, bit Sp> {
  // Type of the A operand, selected by the memory space it lives in.
  LLVMType a_operand_type = !if(!eq(Space, "tensor"),
                                llvm_tmem_ptr_ty, llvm_i64_ty);
  // Arguments common to all variants: d, a, b, the instruction descriptor,
  // and the enable-input-d predicate, optionally followed by spmetadata.
  list<LLVMType> common_args = !listconcat(
    [llvm_tmem_ptr_ty, // d
     a_operand_type,   // a
     llvm_i64_ty,      // b
     llvm_i32_ty,      // idesc
     llvm_i1_ty],      // enable_input_d
    !if(!eq(Sp, 1), [llvm_tmem_ptr_ty], [])); // spmetadata
  // d (arg 0) is write-only; when A is a tensor-memory pointer it is
  // read-only. Only argument memory is accessed.
  list<IntrinsicProperty> common_intr_props = !listconcat(
    [IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
    !if(!eq(Space, "tensor"), [ReadOnly<ArgIndex<1>>], [])
  );
}
965
+
966
// Name builder for the plain tcgen05.mma intrinsic family:
//   llvm.nvvm.tcgen05.mma[.sp].<space>[.scale_d][.ashift]
// `record` is the corresponding TableGen def name (dots -> underscores,
// "llvm." -> "int_") used when defining the intrinsic below.
class NVVM_TCGEN05_MMA<bit Sp, string Space,
                       bit AShift, bit ScaleInputD>:
      NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # !if(!eq(ScaleInputD, 1), ".scale_d", "")
                # !if(!eq(AShift, 1), ".ashift", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
976
+
977
// Name builder for the block-scaled tcgen05.mma intrinsic family:
//   llvm.nvvm.tcgen05.mma[.sp].<space>.<kind>.block_scale[<scale_vec_size>]
// ScaleVecSize is either "" (kind's default scale vector size) or an
// explicit ".block16"/".block32" suffix.
class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
                                  string Kind, string ScaleVecSize>:
      NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # "." # Kind
                # ".block_scale" # ScaleVecSize;
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
987
+
988
// Name builder for the weight-stationary tcgen05.mma.ws intrinsic family:
//   llvm.nvvm.tcgen05.mma.ws[.sp].<space>[.zero_col_mask]
class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
      NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma.ws"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
996
+
997
// Name builder for the tcgen05.mma variants that take a disable-output-lane
// mask; the cta-group is part of the intrinsic name here (cg1/cg2) because
// it determines the mask vector type:
//   llvm.nvvm.tcgen05.mma[.sp].<space>[.scale_d]
//       .disable_output_lane.cg<N>[.ashift]
class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
                                           int CtaGroup, bit AShift,
                                           bit ScaleInputD>:
      NVVM_TCGEN05_MMA_BASE<Space, Sp> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")
                # "." # Space
                # !if(!eq(ScaleInputD, 1), ".scale_d", "")
                # ".disable_output_lane.cg" # CtaGroup
                # !if(!eq(AShift, 1), ".ashift", "");
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
1009
+
1010
// Whitelist of the (Kind, ScaleVecSize) combinations for which a
// block_scale intrinsic is actually defined; every pair not listed here
// falls through to `false` and is skipped by the generator loop below.
// An empty ScaleVecSize stands for the kind's default scale vector size.
class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
  bit ret = !cond(
    !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, "")) : true,
    !and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, "")) : true,
    !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block16")) : true,
    !and(!eq(Kind, "mxf4"), !eq(ScaleVecSize, ".block32")) : true,
    !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block32")) : true,
    !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ".block32")) : true,
    true: false
  );
}
1021
+
950
1022
class TexVector<string name, list<LLVMType> types> {
951
1023
string Name = name;
952
1024
list<LLVMType> Types = types;
@@ -2268,13 +2340,15 @@ def int_nvvm_exit : NVVMBuiltin,
2268
2340
// Intrinsic whose argument list is param_types followed by `flags`; every
// flag argument is automatically marked ImmArg (it must be a compile-time
// immediate). The optional `name` overrides the auto-derived intrinsic
// name — needed by the tcgen05.mma defs whose record names are computed.
class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
                                 list<LLVMType> param_types,
                                 list<LLVMType> flags,
                                 list<IntrinsicProperty> intr_properties,
                                 string name = "">
  : DefaultAttrsIntrinsic<
      ret_types,
      !listconcat(param_types, flags),
      !listconcat(intr_properties,
                  // Flag i sits at argument index size(param_types) + i.
                  !foreach(i, !range(flags),
                           ImmArg<ArgIndex<!add(i, !size(param_types))>>)),
      name>;
2278
2352
2279
2353
// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
2280
2354
foreach dim = 1...5 in {
@@ -2663,4 +2737,136 @@ foreach dim = ["x", "y", "z"] in
2663
2737
: PureIntrinsic<[llvm_i32_ty], [llvm_i128_ty], [],
2664
2738
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
2665
2739
2666
- } // let TargetPrefix = "nvvm"
2740
+ //
2741
+ // tcgen05.mma intrinsics
2742
+ //
2743
+
2744
// Generate the plain tcgen05.mma intrinsics over all valid combinations of
// sparsity, A-operand space, scale_d, and ashift.
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach scale_d = [0, 1] in {
      // .ashift is only generated when A lives in tensor memory.
      foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
        defvar mma = NVVM_TCGEN05_MMA<sp, space, ashift, scale_d>;
        defvar args = !listconcat(
          mma.common_args,
          !if(!eq(scale_d, 1), [llvm_i64_ty], []) // scale_d_imm
        );
        defvar flags = [llvm_i32_ty,  // kind
                        llvm_i32_ty,  // cta_group
                        llvm_i32_ty]; // collector_usage_a
        defvar nargs = !size(args);
        // When present, scale_d_imm is the last non-flag argument; it must
        // be an immediate in [0, 16).
        defvar scale_d_imm = ArgIndex<!sub(nargs, 1)>;
        defvar scale_d_imm_range = [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>];
        defvar intrinsic_properties = !listconcat(
          mma.common_intr_props,
          !if(!eq(scale_d, 1), scale_d_imm_range, []),
          // With scale_d only 2 kind values are valid, otherwise 4; ashift
          // likewise halves the valid collector_usage range.
          [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>, // kind
           Range<ArgIndex<!add(nargs, 1)>, 1, 3>,                 // cta_group
           Range<ArgIndex<!add(nargs, 2)>, 0,
                 !if(!eq(ashift, 1), 2, 4)>                       // collector_usage
          ]
        );

        def mma.record:
          DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                     mma.intr>;
      }
    }
  }
}
2776
+
2777
+ //
2778
+ // tcgen05.mma disable_output_lane intrinsics
2779
+ //
2780
// Generate the tcgen05.mma disable_output_lane intrinsics. The cta-group is
// a template (name) parameter here — not a flag — because it selects the
// type of the trailing lane-mask vector argument.
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach cta_group = [1, 2] in {
      foreach scale_d = [0, 1] in {
        // .ashift is only generated when A lives in tensor memory.
        foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
          defvar mma = NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<
                         sp, space, cta_group, ashift, scale_d>;
          // cg1 uses a v4i32 mask, cg2 a v8i32 mask.
          defvar disable_output_lane_type =
            !if(!eq(cta_group, 1), llvm_v4i32_ty, llvm_v8i32_ty);
          defvar args = !listconcat(
            mma.common_args,
            !if(!eq(scale_d, 1), [llvm_i64_ty], []), // scale_d_imm
            [disable_output_lane_type]
          );
          defvar flags = [llvm_i32_ty,  // kind_flag
                          llvm_i32_ty]; // collector_usage_a_flag
          defvar nargs = !size(args);
          // scale_d_imm sits just before the trailing lane-mask vector,
          // hence nargs - 2 (not nargs - 1 as in the plain variant).
          defvar scale_d_imm = ArgIndex<!sub(nargs, 2)>;
          defvar scale_d_imm_range = [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>];
          defvar intrinsic_properties = !listconcat(
            mma.common_intr_props,
            !if(!eq(scale_d, 1), scale_d_imm_range, []),
            [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
             Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
          );

          def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                                     mma.intr>;
        } // ashift
      } // scale_d
    } // cta_group
  } // space
} // sp
2813
+
2814
+ //
2815
+ // tcgen05.mma block_scale intrinsics
2816
+ //
2817
// Generate the tcgen05.mma block_scale intrinsics; only the (kind,
// scale_vec_size) pairs accepted by NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED
// are emitted.
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
      foreach scale_vec_size = ["", ".block16", ".block32"] in {
        defvar mma = NVVM_TCGEN05_MMA_BLOCKSCALE<sp, space, kind, scale_vec_size>;
        // Block-scaled variants take the two scale matrices from tmem.
        defvar args = !listconcat(mma.common_args,
                                  [llvm_tmem_ptr_ty,   // scale_a
                                   llvm_tmem_ptr_ty]); // scale_b
        defvar flags = [llvm_i32_ty,  // cta_group
                        llvm_i32_ty]; // collector_usage_a
        defvar nargs = !size(args);
        defvar cta_group = ArgIndex<nargs>;
        defvar collector_usage = ArgIndex<!add(nargs, 1)>;

        if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
          def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags,
                            !listconcat(mma.common_intr_props,
                                        [Range<cta_group, 1, 3>,
                                         Range<collector_usage, 0, 4>]),
                            mma.intr>;
        }
      }
    }
  }
}
2842
+
2843
+ //
2844
+ // tcgen05.mma ws intrinsics
2845
+ //
2846
// Generate the weight-stationary tcgen05.mma.ws intrinsics; the optional
// trailing i64 is the zero-column mask descriptor.
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach zero_col_mask = [0, 1] in {
      defvar mma = NVVM_TCGEN05_MMA_WS<sp, space, zero_col_mask>;
      defvar args = !listconcat(
        mma.common_args,
        !if(!eq(zero_col_mask, 1), [llvm_i64_ty], []) // zero_col_mask desc
      );
      defvar flags = [llvm_i32_ty,  // kind
                      llvm_i32_ty,  // collector_buffer_b
                      llvm_i32_ty]; // collector_usage_b_op
      defvar nargs = !size(args);
      // All three flags take values in [0, 4).
      defvar intrinsic_properties = !listconcat(
        mma.common_intr_props,
        [Range<ArgIndex<nargs>, 0, 4>,
         Range<ArgIndex<!add(nargs, 1)>, 0, 4>,
         Range<ArgIndex<!add(nargs, 2)>, 0, 4>]
      );

      def mma.record:
        DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                   mma.intr>;
    }
  }
}
2871
+
2872
+ } // let TargetPrefix = "nvvm"
0 commit comments