@@ -2042,6 +2042,44 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-
// -----
+// Minimal repro for https://github.com/pytorch/pytorch/issues/154933
+//
+// Check that if, during hoisting conversions over ext and broadcast ops,
+// we see multiple different layouts assigned to the same value, then we
+// skip propagation of that layout.
+
+// CHECK-LABEL: @hoist_on_ext_broadcast_mismatch
+#blockedX = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blockedY = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu"} {
+  tt.func public @hoist_on_ext_broadcast_mismatch(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32}) -> tensor<4x1xi64, #blockedY> {
+    %c1_i32 = arith.constant 1 : i32
+    %c4_i32 = arith.constant 4 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %cast0 = arith.extsi %0 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>> to tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %2 = tt.expand_dims %cast0 {axis = 1 : i32} : tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>> -> tensor<4x1xi64, #blockedX>
+    %3 = tt.addptr %1, %cast0 : tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>, tensor<4xi64, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %4 = tt.load %3 : tensor<4x!tt.ptr<i32>, #ttg.slice<{dim = 1, parent = #blockedX}>>
+    %5 = tt.reshape %4 : tensor<4xi32, #ttg.slice<{dim = 1, parent = #blockedX}>> -> tensor<4x1xi32, #blockedX>
+    // CHECK: arith.extsi
+    %6 = arith.extsi %5 : tensor<4x1xi32, #blockedX> to tensor<4x1xi64, #blockedX>
+    %7 = arith.addi %2, %6 : tensor<4x1xi64, #blockedX>
+    // The for loop prevents fully hoisting the conversion.
+    %8 = scf.for %arg2 = %c0_i32 to %c4_i32 step %c1_i32 iter_args(%arg3 = %5) -> (tensor<4x1xi32, #blockedX>) : i32 {
+      scf.yield %5 : tensor<4x1xi32, #blockedX>
+    }
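+    // Because multiple layouts were seen for the same value (see the note above),
+    // the #blockedY layout is not propagated past this point and a convert_layout remains.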
+    // CHECK: ttg.convert_layout
+    %9 = arith.extsi %8 : tensor<4x1xi32, #blockedX> to tensor<4x1xi64, #blockedX>
+    %10 = arith.addi %7, %9 : tensor<4x1xi64, #blockedX>
+    %11 = ttg.convert_layout %10 : tensor<4x1xi64, #blockedX> -> tensor<4x1xi64, #blockedY>
+    tt.return %11 : tensor<4x1xi64, #blockedY>
+  }
+}
+
+// -----
+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 1], order = [1, 0]}>
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [0, 1]}>