@@ -303,3 +303,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
// CHECK-LABEL: doNotFuseLoadWithTrans4
// CHECK: tt.trans
}

// -----

#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  // COM: Ensure tt.trans is not fused with tt.load when the load uses a pointer yielded by a function call.
  // COM: @func returns one of its two pointer arguments, so the pointer fed to the
  // COM: load below cannot be traced back to a single tt.make_tensor_ptr statically.
  tt.func @func(%cond: i1, %p1: !tt.ptr<tensor<32x64xf16, #linear>>, %p2: !tt.ptr<tensor<32x64xf16, #linear>>) -> !tt.ptr<tensor<32x64xf16, #linear>> attributes {noinline = true} {
    %0 = arith.select %cond, %p1, %p2 : i1, !tt.ptr<tensor<32x64xf16, #linear>>
    tt.return %0 : !tt.ptr<tensor<32x64xf16, #linear>>
  }
  tt.func public @doNotFuseLoadWithTrans5(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %cond: i1) {
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %c64_i64 = arith.constant 64 : i64
    %c1_i64 = arith.constant 1 : i64
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
    %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
    %24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
      %adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
      %adv2 = tt.advance %9, [%c0_i32, %arg13] : <tensor<32x64xf16, #linear>>
      // COM: The loaded pointer is the opaque result of a call, not a visible chain
      // COM: of make_tensor_ptr/advance ops, so the fusion pass must bail out.
      %adv3 = tt.call @func(%cond, %adv1, %adv2) : (i1, !tt.ptr<tensor<32x64xf16, #linear>>, !tt.ptr<tensor<32x64xf16, #linear>>) -> !tt.ptr<tensor<32x64xf16, #linear>>
      %load1 = tt.load %adv3 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
      %trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
      %76 = arith.addi %arg13, %c32_i32 : i32
      scf.yield %76 : i32
    }
    tt.return
  }
  // CHECK-LABEL: doNotFuseLoadWithTrans5
  // CHECK: tt.trans
}