@@ -764,6 +764,76 @@ class NVVM_TCGEN05_LDST_ACCESS_SIZE<string Shape, int Num> {
764
764
true : llvm_void_ty);
765
765
}
766
766
767
// Shared operand list and intrinsic properties for every tcgen05.mma
// intrinsic flavour.  Space selects where the A operand lives:
// "tensor" (tensor memory, addressed through a tmem pointer) or
// "shared" (encoded as an i64 shared-memory descriptor).
class NVVM_TCGEN05_MMA_BASE<string Space> {
  // A is a tmem pointer in the "tensor" space, otherwise an i64 descriptor.
  LLVMType a_operand_type = !if(!eq(Space, "tensor"),
                                llvm_tmem_ptr_ty, llvm_i64_ty);
  list<LLVMType> common_args = [llvm_tmem_ptr_ty, // d
                                a_operand_type,   // a
                                llvm_i64_ty,      // b
                                llvm_i32_ty,      // idesc
                                llvm_i1_ty];      // enable_input_d
  // All flavours only touch argument-pointed memory and write the
  // accumulator D (arg 0); when A is a tmem pointer it is read-only.
  list<IntrinsicProperty> common_intr_props = !listconcat(
    [IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
    !if(!eq(Space, "tensor"), [ReadOnly<ArgIndex<1>>], [])
  );
}
780
+
781
// Builds the intrinsic-name string and the TableGen record name for a
// plain tcgen05.mma intrinsic, e.g.
// "llvm.nvvm.tcgen05.mma.sp.tensor.scale_d.ashift".
class NVVM_TCGEN05_MMA<bit Sp, string Space,
                       bit AShift, bit ScaleInputD>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")   // sparse (.sp) variant
                # "." # Space
                # !if(!eq(ScaleInputD, 1), ".scale_d", "")
                # !if(!eq(AShift, 1), ".ashift", "");
  // Derive the def name: "llvm.a.b" -> "int_a_b".
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
791
+
792
// Builds the intrinsic-name string and record name for the block-scaled
// tcgen05.mma flavour, e.g.
// "llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale".
// ScaleVecSize is "" (kind default), ".block16" or ".block32".
class NVVM_TCGEN05_MMA_BLOCKSCALE<bit Sp, string Space,
                                  string Kind, string ScaleVecSize>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")   // sparse (.sp) variant
                # "." # Space
                # "." # Kind # ScaleVecSize
                # ".block_scale";
  // Derive the def name: "llvm.a.b" -> "int_a_b".
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
802
+
803
// Builds the intrinsic-name string and record name for the weight-stationary
// (ws) tcgen05.mma flavour, e.g.
// "llvm.nvvm.tcgen05.mma.ws.sp.shared.zero_col_mask".
class NVVM_TCGEN05_MMA_WS<bit Sp, string Space, bit ZeroColMask>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma.ws"
                # !if(!eq(Sp, 1), ".sp", "")   // sparse (.sp) variant
                # "." # Space
                # !if(!eq(ZeroColMask, 1), ".zero_col_mask", "");
  // Derive the def name: "llvm.a.b" -> "int_a_b".
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
811
+
812
// Builds the intrinsic-name string and record name for the
// disable_output_lane tcgen05.mma flavour; the CTA-group number is part
// of the name, e.g.
// "llvm.nvvm.tcgen05.mma.tensor.scale_d.disable_output_lane.cg2.ashift".
class NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<bit Sp, string Space,
                                           int CtaGroup, bit AShift,
                                           bit ScaleInputD>:
      NVVM_TCGEN05_MMA_BASE<Space> {
  string intr = "llvm.nvvm.tcgen05.mma"
                # !if(!eq(Sp, 1), ".sp", "")   // sparse (.sp) variant
                # "." # Space
                # !if(!eq(ScaleInputD, 1), ".scale_d", "")
                # ".disable_output_lane.cg" # CtaGroup
                # !if(!eq(AShift, 1), ".ashift", "");
  // Derive the def name: "llvm.a.b" -> "int_a_b".
  string record = !subst(".", "_", !subst("llvm.", "int_", intr));
}
824
+
825
// Whitelist of the valid (Kind, ScaleVecSize) combinations for the
// block_scale intrinsics; every pairing not listed below yields false,
// so the foreach that instantiates them skips the invalid combinations.
// An empty ScaleVecSize denotes the kind's default scale-vector size.
class NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<string Kind, string ScaleVecSize> {
  bit ret = !cond(
    !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ""))          : true,
    !and(!eq(Kind, "mxf4"),     !eq(ScaleVecSize, ""))          : true,
    !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block16"))  : true,
    !and(!eq(Kind, "mxf4"),     !eq(ScaleVecSize, ".block32"))  : true,
    !and(!eq(Kind, "mxf4nvf4"), !eq(ScaleVecSize, ".block32"))  : true,
    !and(!eq(Kind, "mxf8f6f4"), !eq(ScaleVecSize, ".block32"))  : true,
    true: false   // everything else is unsupported
  );
}
836
+
767
837
class TexVector<string name, list<LLVMType> types> {
768
838
string Name = name;
769
839
list<LLVMType> Types = types;
@@ -2070,13 +2140,15 @@ def int_nvvm_exit : NVVMBuiltin,
2070
2140
// Like DefaultAttrsIntrinsic, but with a trailing list of i32 "flag"
// parameters that are appended to param_types and automatically marked
// ImmArg (each flag's ArgIndex is offset by !size(param_types)).
// The optional `name` (default "") is forwarded so callers can override
// the auto-derived intrinsic name; the default keeps existing users
// source-compatible.
class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
                                 list<LLVMType> param_types,
                                 list<LLVMType> flags,
                                 list<IntrinsicProperty> intr_properties,
                                 string name = "">
  : DefaultAttrsIntrinsic<
      ret_types,
      !listconcat(param_types, flags),
      !listconcat(intr_properties,
                  !foreach(i, !range(flags),
                           // Flags sit after the regular params.
                           ImmArg<ArgIndex<!add(i, !size(param_types))>>)),
      name>;
2080
2152
2081
2153
// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
2082
2154
foreach dim = 1...5 in {
@@ -2464,4 +2536,139 @@ def int_nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_ # dim
2464
2536
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
2465
2537
}
2466
2538
2467
- } // let TargetPrefix = "nvvm"
2539
//
// tcgen05.mma intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach scale_d = [0, 1] in {
      // .ashift only exists when A resides in tensor memory.
      foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
        defvar mma = NVVM_TCGEN05_MMA<sp, space, ashift, scale_d>;
        // Optional trailing args: sparse metadata pointer (.sp) and the
        // scale_d immediate (.scale_d), in that order.
        defvar args = !listconcat(
          mma.common_args,
          !if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
          !if(!eq(scale_d, 1), [llvm_i64_ty], [])
        );
        defvar flags = [llvm_i32_ty,  // kind
                        llvm_i32_ty,  // cta_group
                        llvm_i32_ty]; // collector_usage_a
        defvar nargs = !size(args);
        // When present, scale_d is the last non-flag argument.
        defvar scale_d_imm = ArgIndex<!sub(nargs, 1)>;
        defvar intrinsic_properties = !listconcat(
          mma.common_intr_props,
          !if(!eq(scale_d, 1),
              [ImmArg<scale_d_imm>, Range<scale_d_imm, 0, 16>], []),
          // Flag value ranges (flags follow the args, so they start at
          // ArgIndex<nargs>): kind is narrowed with scale_d, and
          // collector_usage_a is narrowed with ashift.
          [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
           Range<ArgIndex<!add(nargs, 1)>, 1, 3>,
           Range<ArgIndex<!add(nargs, 2)>, 0, !if(!eq(ashift, 1), 2, 4)>]
        );

        def mma.record:
          DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                     mma.intr>;
      }
    }
  }
}
2574
+
2575
//
// tcgen05.mma disable_output_lane intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach cta_group = [1, 2] in {
      foreach scale_d = [0, 1] in {
        // .ashift only exists when A resides in tensor memory.
        foreach ashift = !if(!eq(space, "tensor"), [0, 1], [0]) in {
          defvar mma = NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE<
                           sp, space, cta_group, ashift, scale_d>;
          // The lane mask is a v4i32 for cta_group 1 and v8i32 for 2.
          defvar disable_output_lane_type =
              !if(!eq(cta_group, 1), llvm_v4i32_ty, llvm_v8i32_ty);
          // Optional args (spmetadata, scale_d) come before the mask,
          // which is always the last non-flag argument.
          defvar args = !listconcat(
            mma.common_args,
            !if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
            !if(!eq(scale_d, 1), [llvm_i64_ty], []),
            [disable_output_lane_type]
          );
          defvar flags = [llvm_i32_ty,  // kind_flag
                          llvm_i32_ty]; // collector_usage_a_flag
          defvar nargs = !size(args);
          // scale_d (when present) sits just before the lane mask.
          defvar scale_d_flag = ArgIndex<!sub(nargs, 2)>;
          defvar scale_d_imm_range = [ImmArg<scale_d_flag>, Range<scale_d_flag, 0, 16>];
          defvar intrinsic_properties = !listconcat(
            mma.common_intr_props,
            !if(!eq(scale_d, 1), scale_d_imm_range, []),
            // Flag ranges: kind narrowed with scale_d, collector usage
            // narrowed with ashift.
            [Range<ArgIndex<nargs>, 0, !if(!eq(scale_d, 1), 2, 4)>,
             Range<ArgIndex<!add(nargs, 1)>, 0, !if(!eq(ashift, 1), 2, 4)>]
          );

          def mma.record: DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                                     mma.intr>;
        } // ashift
      } // scale_d
    } // cta_group
  } // space
} // sp
2612
+
2613
//
// tcgen05.mma block_scale intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach kind = ["mxf8f6f4", "mxf4", "mxf4nvf4"] in {
      foreach scale_vec_size = ["", ".block16", ".block32"] in {
        defvar mma = NVVM_TCGEN05_MMA_BLOCKSCALE<sp, space, kind, scale_vec_size>;
        // Flag positions: 5 common args, +1 spmetadata when sparse,
        // +2 scale pointers -> flags start at index 8 (sp) or 7.
        defvar cta_group = ArgIndex<!if(!eq(sp, 1), 8, 7)>;
        defvar collector_usage = ArgIndex<!if(!eq(sp, 1), 9, 8)>;

        // Only instantiate the (kind, scale_vec_size) pairs that are
        // actually valid; see NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED.
        if NVVM_TCGEN05_MMA_BLOCKSCALE_SUPPORTED<kind, scale_vec_size>.ret then {
          def mma.record: DefaultAttrsIntrinsicFlags<[],
            !listconcat(mma.common_args,
                        !if(!eq(sp, 1),
                            [llvm_tmem_ptr_ty], []), // spmetadata
                        [llvm_tmem_ptr_ty,   // scale a
                         llvm_tmem_ptr_ty]), // scale b
            // flags
            [llvm_i32_ty,  // cta_group
             llvm_i32_ty], // collector_usage_a
            !listconcat(mma.common_intr_props,
                        [Range<cta_group, 1, 3>,
                         Range<collector_usage, 0, 4>]),
            mma.intr>;
        }
      }
    }
  }
}
2643
+
2644
//
// tcgen05.mma ws intrinsics
//
foreach sp = [0, 1] in {
  foreach space = ["tensor", "shared"] in {
    foreach zero_col_mask = [0, 1] in {
      defvar mma = NVVM_TCGEN05_MMA_WS<sp, space, zero_col_mask>;
      // Optional trailing args: sparse metadata pointer (.sp) and the
      // zero-column mask (i64), in that order.
      defvar args = !listconcat(
        mma.common_args,
        !if(!eq(sp, 1), [llvm_tmem_ptr_ty], []),
        !if(!eq(zero_col_mask, 1), [llvm_i64_ty], [])
      );
      defvar flags = [llvm_i32_ty,  // kind
                      llvm_i32_ty,  // collector_buffer_b
                      llvm_i32_ty]; // collector_usage_b_op
      defvar nargs = !size(args);
      // Flags follow the args, so their indices start at nargs.
      defvar intrinsic_properties = !listconcat(
        mma.common_intr_props,
        [Range<ArgIndex<nargs>, 0, 4>,
         Range<ArgIndex<!add(nargs, 1)>, 0, 4>,
         Range<ArgIndex<!add(nargs, 2)>, 0, 4>]
      );

      def mma.record:
        DefaultAttrsIntrinsicFlags<[], args, flags, intrinsic_properties,
                                   mma.intr>;
    }
  }
}
2673
+
2674
+ } // let TargetPrefix = "nvvm"
0 commit comments