Skip to content

Commit 9a77cdb

Browse files
[MaterializeBlockPointer] Handle i64 element type (#4759)
Inductor CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/16429761342 (pass)
Fixes #4725
Signed-off-by: Whitney Tsang <[email protected]>
1 parent be3fdef commit 9a77cdb

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

test/TritonIntelGPU/materialize-block-pointer.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
185185
tt.return
186186
}
187187
}
188+
189+
// -----
190+
191+
// COM: Ensure i64 element type is supported in materialize block pointer.
192+
193+
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
194+
#dot_a = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth = 1}>
195+
module attributes {"ttg.num-ctas" = 1 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
196+
// CHECK-LABEL: tt.func public @materialize_block_pointer(
197+
tt.func public @materialize_block_pointer(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %pitch: i64 {tt.divisibility = 16 : i32}) {
198+
%c0_i32 = arith.constant 0 : i32
199+
%c0_i64 = arith.constant 0 : i64
200+
%c1_i64 = arith.constant 1 : i64
201+
202+
// CHECK: tt.load {{.*}} {ttig.block_io = "row_major"}
203+
%0 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xi64, #dot_a>>
204+
%1 = tt.load %0 : !tt.ptr<tensor<64x32xi64, #dot_a>>
205+
// CHECK: tt.store {{.*}} {ttig.block_io = "row_major"}
206+
tt.store %0, %1 : !tt.ptr<tensor<64x32xi64, #dot_a>>
207+
208+
tt.return
209+
}
210+
}

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
116116
Value pitch =
117117
strides[(strideOneDimVal == rank - 1) ? rank - 2 : rank - 1];
118118
LDBG("Pitch: " << pitch);
119-
if (!ttgi::isDivisible(pitch, 128 / elementWidth))
119+
if (!ttgi::isDivisible(pitch, llvm::divideCeil(128, elementWidth)))
120120
return;
121121

122122
const bool isRowMajor = (strideOneDimVal == rank - 1);
@@ -336,7 +336,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
336336
// Analyze the shape of the stride one dimension to ensure it satisfies HW
337337
// constraints.
338338
Value baseWidth = tt::intel::getFinalValue(shape[strideOneDimVal]);
339-
unsigned divisor = std::ceil(32 / elementWidth);
339+
unsigned divisor = llvm::divideCeil(32, elementWidth);
340340
if (!ttgi::isDivisible(baseWidth, divisor)) {
341341
LLVM_DEBUG({
342342
llvm::dbgs() << "baseWidth does not satisfies HW constraint: ";

0 commit comments

Comments
 (0)