@@ -171,25 +171,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
171171
172172// -----
173173
174- // COM: Ensure pointer with stride [0, 1] is considered as row major.
174+ // COM: Ensure pointers with strides [0, 1] and [1, 0] are considered row major and column major, respectively.
175175#blocked = #ttg.blocked <{sizePerThread = [1 , 8 ], threadsPerWarp = [4 , 4 ], warpsPerCTA = [32 , 1 ], order = [1 , 0 ]}>
176176module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 32 : i32 , " ttg.threads-per-warp" = 16 : i32 , ttig.support_sg_2d_block } {
177177 tt.func public @tensor_of_ptr (%arg0: !tt.ptr <bf16 > {tt.divisibility = 16 : i32 }) {
178- %18 = tt.make_range {end = 32 : i32 , start = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
179- %19 = tt.expand_dims %18 {axis = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x32 xi32 , #blocked >
180- %20 = tt.splat %arg0 : !tt.ptr <bf16 > -> tensor <1 x32 x!tt.ptr <bf16 >, #blocked >
181- %21 = tt.addptr %20 , %19 : tensor <1 x32 x!tt.ptr <bf16 >, #blocked >, tensor <1 x32 xi32 , #blocked >
182- %22 = tt.broadcast %21 : tensor <1 x32 x!tt.ptr <bf16 >, #blocked > -> tensor <256 x32 x!tt.ptr <bf16 >, #blocked >
178+ %0 = tt.make_range {end = 32 : i32 , start = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>>
179+ %1 = tt.expand_dims %0 {axis = 0 : i32 } : tensor <32 xi32 , #ttg.slice <{dim = 0 , parent = #blocked }>> -> tensor <1 x32 xi32 , #blocked >
180+ %2 = tt.splat %arg0 : !tt.ptr <bf16 > -> tensor <1 x32 x!tt.ptr <bf16 >, #blocked >
181+ %3 = tt.addptr %2 , %1 : tensor <1 x32 x!tt.ptr <bf16 >, #blocked >, tensor <1 x32 xi32 , #blocked >
182+ %4 = tt.broadcast %3 : tensor <1 x32 x!tt.ptr <bf16 >, #blocked > -> tensor <256 x32 x!tt.ptr <bf16 >, #blocked >
183183 // CHECK: tt.load {{.*}} {ttig.block_io = "row_major"}
184- %50 = tt.load %22 : tensor <256 x32 x!tt.ptr <bf16 >, #blocked >
184+ tt.load %4 : tensor <256 x32 x!tt.ptr <bf16 >, #blocked >
185+
186+ %6 = tt.make_range {end = 256 : i32 , start = 0 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>>
187+ %7 = tt.expand_dims %6 {axis = 1 : i32 } : tensor <256 xi32 , #ttg.slice <{dim = 1 , parent = #blocked }>> -> tensor <256 x1 xi32 , #blocked >
188+ %8 = tt.splat %arg0 : !tt.ptr <bf16 > -> tensor <256 x1 x!tt.ptr <bf16 >, #blocked >
189+ %9 = tt.addptr %8 , %7 : tensor <256 x1 x!tt.ptr <bf16 >, #blocked >, tensor <256 x1 xi32 , #blocked >
190+ %10 = tt.broadcast %9 : tensor <256 x1 x!tt.ptr <bf16 >, #blocked > -> tensor <256 x32 x!tt.ptr <bf16 >, #blocked >
191+ // CHECK: tt.load {{.*}} {ttig.block_io = "column_major"}
192+ tt.load %10 : tensor <256 x32 x!tt.ptr <bf16 >, #blocked >
185193 tt.return
186194 }
187195}
188196
189197// -----
190198
191199// COM: Ensure i64 element type is supported in materialize block pointer.
192-
193200#dpas = #ttig.dpas <{repeatCount = 8 , systolicDepth = 8 , executionSize = 16 , opsPerChan = 2 , threadsPerWarp = 16 , warpsPerCTA = [4 , 2 ], repCluster = [1 , 1 ], A = [8 , 16 ], B = [16 , 16 ], C = [8 , 16 ]}>
194201#dot_a = #ttg.dot_op <{opIdx = 0 , parent = #dpas , kWidth = 1 }>
195202module attributes {" ttg.num-ctas" = 1 : i32 , ttg.target = " xpu" , " ttg.num-warps" = 8 : i32 , " ttg.threads-per-warp" = 16 : i32 , ttig.support_sg_2d_block } {
0 commit comments