@@ -946,6 +946,80 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
#dot1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
  // MMAv2 -> dot-operand conversion within a single warp: the layouts are
  // register-compatible, so no shared-memory round trip should be emitted.
  // CHECK-LABEL: convert_layout_mmav2_dot_reg
  tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
    // CHECK-NOT: st.shared
    // CHECK-NOT: llvm.load
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
    tt.return
  }
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  // MMAv3 -> MMAv3 conversion (instrShape 64 -> 128 on the N dimension):
  // expected to be resolved in registers with no shared-memory traffic.
  // CHECK-LABEL: convert_layout_mmav3_mmav3_0
  tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
    // CHECK-NOT: st.shared
    // CHECK-NOT: llvm.load
    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
    tt.return
  }
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  // Reverse direction of the previous case (instrShape 128 -> 64):
  // must also avoid a shared-memory round trip.
  // CHECK-LABEL: convert_layout_mmav3_mmav3_1
  tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
    // CHECK-NOT: st.shared
    // CHECK-NOT: llvm.load
    %0 = triton_gpu.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
    tt.return
  }
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  // Same MMAv3 -> MMAv3 conversion on a tensor smaller than one instrShape
  // tile (16x16): still expected to stay in registers.
  // CHECK-LABEL: convert_layout_mmav3_mmav3_2
  tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
    // CHECK-NOT: st.shared
    // CHECK-NOT: llvm.load
    %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
    tt.return
  }
}

// -----

#mma0 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>

module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
  // Degenerate-row case (1x64): exercises the conversion when the tensor has
  // a single row; still expected to avoid shared memory.
  // CHECK-LABEL: convert_layout_mmav3_mmav3_3
  tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
    // CHECK-NOT: st.shared
    // CHECK-NOT: llvm.load
    %0 = triton_gpu.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
    tt.return
  }
}

// -----

9491023#blocked = #triton_gpu.blocked <{sizePerThread = [16 , 1 ], threadsPerWarp = [8 , 4 ], warpsPerCTA = [1 , 8 ], order = [0 , 1 ]}>
9501024#mma = #triton_gpu.nvidia_mma <{versionMajor = 3 , versionMinor = 0 , warpsPerCTA = [8 , 1 ], instrShape = [16 , 256 , 32 ]}>
9511025module attributes {" triton_gpu.num-ctas" = 1 : i32 , " triton_gpu.num-warps" = 8 : i32 } {