@@ -160,3 +160,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %6 : tensor<128x16xf32, #mma>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [8, 1], instrShape = [16, 16, 16], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: tt.func @chained_dots_with_load_bias_in_between
+
+  // Similar to the previous test, but loads a bias tensor between the 2 dots.
+  // We expect the unstreamable load to be kept after pipelining.
173+
174+ // CHECK: scf.for
175+ // CHECK: tt.dot
176+ // CHECK: ttg.async_copy_global_to_local
177+ // CHECK: tt.dot
178+ // CHECK: ttg.async_wait
179+ // CHECK: ttg.local_load
180+ // CHECK: tt.load
181+ // CHECK: scf.yield
182+
+  tt.func @chained_dots_with_load_bias_in_between(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg2: i64 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg4: i32) -> tensor<256x64xf32, #mma> {
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i32 = arith.constant 1 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x64xf32, #mma>
+    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked>
+    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<64x64x!tt.ptr<f16>, #blocked>
+    %3 = tt.broadcast %1 : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked>
+    %4 = tt.addptr %2, %3 : tensor<64x64x!tt.ptr<f16>, #blocked>, tensor<64x64xi32, #blocked>
+    %5 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %6 = tt.splat %arg3 : !tt.ptr<f16> -> tensor<256x64x!tt.ptr<f16>, #blocked>
+    %7 = scf.for %arg5 = %c0_i32 to %arg4 step %c1_i32 iter_args(%arg6 = %cst) -> (tensor<256x64xf32, #mma>) : i32 {
+      %8 = tt.load %4 : tensor<64x64x!tt.ptr<f16>, #blocked>
+      %9 = ttg.convert_layout %8 : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
+      %10 = tt.dot %arg1, %9, %cst : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma>
+      %11 = arith.muli %arg5, %c64_i32 : i32
+      %12 = tt.splat %11 : i32 -> tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+      %13 = arith.addi %12, %5 : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+      %14 = tt.expand_dims %13 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
+      %15 = tt.broadcast %14 : tensor<1x64xi32, #blocked> -> tensor<256x64xi32, #blocked>
+      %bias_ptr = tt.addptr %6, %15 : tensor<256x64x!tt.ptr<f16>, #blocked>, tensor<256x64xi32, #blocked>
+      %bias = tt.load %bias_ptr : tensor<256x64x!tt.ptr<f16>, #blocked>
+      %bias_mma = ttg.convert_layout %bias : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #mma>
+      %bias_f32 = arith.extf %bias_mma : tensor<256x64xf16, #mma> to tensor<256x64xf32, #mma>
+      %dot_bias = arith.addf %10, %bias_f32 : tensor<256x64xf32, #mma>
+      %21 = arith.truncf %dot_bias : tensor<256x64xf32, #mma> to tensor<256x64xf16, #mma>
+      %22 = ttg.convert_layout %21 : tensor<256x64xf16, #mma> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
+      %23 = tt.dot %22, %9, %arg6 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x64xf32, #mma>
+      scf.yield %23 : tensor<256x64xf32, #mma>
+    }
+    tt.return %7 : tensor<256x64xf32, #mma>
+  }
+}