Skip to content

Commit 42793a2

Browse files
authored
[MaterializeBlockPointer]: Set blockIO attribute for column major accesses (#5066)
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 632d234 commit 42793a2

File tree

2 files changed

+23 -9 lines changed

test/TritonIntelGPU/materialize-block-pointer.mlir

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -171,25 +171,32 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
171171

172172
// -----
173173

174-
// COM: Ensure pointer with stride [0, 1] is considered as row major.
174+
// COM: Ensure pointers with strides [0, 1]/[1, 0] are considered row/column major respectively.
175175
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
176176
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
177177
tt.func public @tensor_of_ptr(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}) {
178-
%18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
179-
%19 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
180-
%20 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x32x!tt.ptr<bf16>, #blocked>
181-
%21 = tt.addptr %20, %19 : tensor<1x32x!tt.ptr<bf16>, #blocked>, tensor<1x32xi32, #blocked>
182-
%22 = tt.broadcast %21 : tensor<1x32x!tt.ptr<bf16>, #blocked> -> tensor<256x32x!tt.ptr<bf16>, #blocked>
178+
%0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
179+
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x32xi32, #blocked>
180+
%2 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<1x32x!tt.ptr<bf16>, #blocked>
181+
%3 = tt.addptr %2, %1 : tensor<1x32x!tt.ptr<bf16>, #blocked>, tensor<1x32xi32, #blocked>
182+
%4 = tt.broadcast %3 : tensor<1x32x!tt.ptr<bf16>, #blocked> -> tensor<256x32x!tt.ptr<bf16>, #blocked>
183183
// CHECK: tt.load {{.*}} {ttig.block_io = "row_major"}
184-
%50 = tt.load %22 : tensor<256x32x!tt.ptr<bf16>, #blocked>
184+
tt.load %4 : tensor<256x32x!tt.ptr<bf16>, #blocked>
185+
186+
%6 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
187+
%7 = tt.expand_dims %6 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked>
188+
%8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x1x!tt.ptr<bf16>, #blocked>
189+
%9 = tt.addptr %8, %7 : tensor<256x1x!tt.ptr<bf16>, #blocked>, tensor<256x1xi32, #blocked>
190+
%10 = tt.broadcast %9 : tensor<256x1x!tt.ptr<bf16>, #blocked> -> tensor<256x32x!tt.ptr<bf16>, #blocked>
191+
// CHECK: tt.load {{.*}} {ttig.block_io = "column_major"}
192+
tt.load %10 : tensor<256x32x!tt.ptr<bf16>, #blocked>
185193
tt.return
186194
}
187195
}
188196

189197
// -----
190198

191199
// COM: Ensure i64 element type is supported in materialize block pointer.
192-
193200
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
194201
#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
195202
module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -217,10 +217,17 @@ struct TritonIntelGPUMaterializeBlockPointerPass
217217
return true;
218218
};
219219

220+
const StringRef blockIOAttrName =
221+
ttgi::TritonIntelGPUDialect::getBlockIOAttrName();
220222
const bool isRowMajor = isMajor(tensorTy, 1 /*fastChangeDim*/, *axisInfo);
221223
if (isRowMajor)
222-
op->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
224+
op->setAttr(blockIOAttrName,
223225
StringAttr::get(op.getContext(), "row_major"));
226+
227+
const bool isColMajor = isMajor(tensorTy, 0 /*fastChangeDim*/, *axisInfo);
228+
if (isColMajor)
229+
op->setAttr(blockIOAttrName,
230+
StringAttr::get(op.getContext(), "column_major"));
224231
}
225232

226233
// Return the load layout if it is a dot layout. If it is not, check if the

0 commit comments

Comments
 (0)