@@ -70,8 +70,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
   // COM: tt.load -> tt.trans -> tt.dot chain, in a loop.
   // COM: where the 'make_tensor_ptr' result is loop carried.
-  tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
-    %c4_i32 = arith.constant 4 : i32
+  tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
     %c1024_i32 = arith.constant 1024 : i32
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
     %c1_i64 = arith.constant 1 : i64
     %c1024_i64 = arith.constant 1024 : i64
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
-    %0 = tt.get_program_id x : i32
-    %1 = arith.divsi %0, %c16_i32 : i32
-    %2 = arith.muli %1, %c4_i32 : i32
-    %3 = arith.subi %c4_i32, %2 : i32
-    %4 = arith.minsi %3, %c4_i32 : i32
-    %5 = arith.remsi %0, %c16_i32 : i32
-    %6 = arith.remsi %5, %4 : i32
-    %7 = arith.addi %2, %6 : i32
-    %8 = arith.divsi %5, %4 : i32
     %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
     %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
     %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
@@ -116,13 +106,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
 
 // -----
 
+#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
+#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
+  // COM: tt.load -> tt.trans -> tt.dot chain, in 2 loops.
+  // COM: where the block ptr used by the loads in the 2 loops is created by the same 'make_tensor_ptr' operation.
+  tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
+    %c32_i32 = arith.constant 32 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c64_i64 = arith.constant 64 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
+    %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+    %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
+    %24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+    %25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
+    %29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
+      %adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
+      %load1 = tt.load %adv1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
+      %trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+      %76 = arith.addi %arg13, %c32_i32 : i32
+      scf.yield %76 : i32
+    }
+    %38:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
+      %adv2 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
+      %load2 = tt.load %adv2 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
+      %trans2 = tt.trans %load2 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
+      %dot2 = tt.dot %25, %trans2, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+      %81 = arith.addi %arg13, %c32_i32 : i32
+      scf.yield %81 : i32
+    }
+    tt.return
+  }
+  // CHECK-LABEL: fuseLoadWithTrans4
+  // CHECK-NOT: tt.trans
+  // CHECK-COUNT-2: tt.make_tensor_ptr %arg2, [%c64_i64, %c1_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK: scf.for {{.*}}
+  // CHECK:   [[ADV1:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   [[LOAD_B1:%.*]] = tt.load [[ADV1]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+  // CHECK:   scf.yield {{.*}}
+  // CHECK: scf.for {{.*}}
+  // CHECK:   [[ADV2:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   [[LOAD_B2:%.*]] = tt.load [[ADV2]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
+  // CHECK:   tt.dot {{.*}}, [[LOAD_B2]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
+  // CHECK:   scf.yield {{.*}}
+}
+
+// -----
+
 #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}>
 #mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
-  // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load
-  // COM: that 'feeds' the transpose operation is used.
-  tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
-    %c4_i32 = arith.constant 4 : i32
+  // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load that 'feeds' the transpose operation is used.
+  tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
     %c1024_i32 = arith.constant 1024 : i32
     %c0_i32 = arith.constant 0 : i32
     %c32_i32 = arith.constant 32 : i32
@@ -131,15 +169,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
     %c1_i64 = arith.constant 1 : i64
     %c1024_i64 = arith.constant 1024 : i64
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
-    %0 = tt.get_program_id x : i32
-    %1 = arith.divsi %0, %c16_i32 : i32
-    %2 = arith.muli %1, %c4_i32 : i32
-    %3 = arith.subi %c4_i32, %2 : i32
-    %4 = arith.minsi %3, %c4_i32 : i32
-    %5 = arith.remsi %0, %c16_i32 : i32
-    %6 = arith.remsi %5, %4 : i32
-    %7 = arith.addi %2, %6 : i32
-    %8 = arith.divsi %5, %4 : i32
     %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
     %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
     %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
@@ -166,7 +195,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
   // COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose.
   // COM: In this case `%19` is used by the load that feeds the transpose and by a store operation.
-  tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
+  tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
     %c4_i32 = arith.constant 4 : i32
     %c1024_i32 = arith.constant 1024 : i32
     %c0_i32 = arith.constant 0 : i32