@@ -852,18 +852,77 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // -----
 
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_dot_ldmatrix
+  tt.func @convert_dot_ldmatrix(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+    %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    // CHECK: nvgpu.ldmatrix
+    // CHECK: nvgpu.ldmatrix
+    // CHECK-NOT: nvgpu.ldmatrix
+    %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
+    %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 8, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_dot
-  tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+  tt.func @convert_dot_ldmatrix_swizzle(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
     %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
     // CHECK: nvgpu.ldmatrix
     // CHECK: nvgpu.ldmatrix
+    // CHECK-NOT: nvgpu.ldmatrix
+    %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
+    %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
+    %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
+
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    // CHECK: llvm.inline_asm
+    // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32
+    %D = tt.dot %AA_DOT, %BB_DOT, %cst0 : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0>
+
+    tt.return
+  }
+}
+
+// -----
+
+#blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 1, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
+  // CHECK-LABEL: convert_dot
+  tt.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) {
+    %AA = ttg.local_alloc %A : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    %BB = ttg.local_alloc %B : (tensor<16x16xf16, #blocked0>) -> !ttg.memdesc<16x16xf16, #shared0, #smem>
+    // CHECK-NOT: nvgpu.ldmatrix
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_a>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<16x16xf16, #shared0, #smem> -> tensor<16x16xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0>
@@ -905,7 +964,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // -----
 
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 16, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma0 = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma0, kWidth=4}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma0, kWidth=4}>
@@ -1206,7 +1265,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // -----
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
@@ -1255,7 +1314,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // -----
 
 #mma = #ttg.nvidia_mma<{versionMajor=2, warpsPerCTA=[2, 2], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
-#shared = #ttg.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 4, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
@@ -1744,7 +1803,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 // -----
 
 #blocked0 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#shared0 = #ttg.shared<{vec = 1, perPhase=2, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#shared0 = #ttg.shared<{vec = 8, perPhase=1, maxPhase=8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #mma = #ttg.nvidia_mma<{versionMajor = 2, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#mma, kWidth=2}>