Skip to content

Commit 59edf2c

Browse files
[convertBlockPtrToTensorOfPtr] Do not generate mask when boundary_check is not set (#2366)
This matches `RewriteTensorPointer`, which also does not generate a mask when `boundary_check` is not set: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp#L206
Signed-off-by: Whitney Tsang <[email protected]>
1 parent db07b9e commit 59edf2c

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
261261
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
262262
module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
263263
// CHECK-LABEL: llvm.func spir_kernelcc @non_contiguous_load_dot_layout
264+
// COM: Check mask is not generated when boundary_check is not set.
265+
// CHECK-NOT: llvm.icmp "slt"
264266
tt.func public @non_contiguous_load_dot_layout(%arg0: !tt.ptr<f16>, %col_stride: i64) {
265267
%c64_i64 = arith.constant 64 : i64
266268
%c1_i64 = arith.constant 1 : i64

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ struct LoadStoreConversionBase {
230230
};
231231

232232
SmallVector<Value> ptrElems(numElems);
233-
SmallVector<Value> maskElems(numElems);
233+
SmallVector<Value> maskElems;
234234
for (unsigned i = 0; i < numElems; ++i) {
235235
auto index = indices[i];
236236
SmallVector<Value> indicesInTensor(rank);
@@ -251,15 +251,17 @@ struct LoadStoreConversionBase {
251251
ptrElems[i] = gep(ptr_ty(rewriter.getContext(), 1 /*global*/),
252252
valueElemTy, blockPtr[blockBase], offset);
253253

254-
// Get the LLVM values for mask
255-
maskElems[i] = linearize(
256-
indicesInTensor,
257-
{blockPtr.begin() + blockShape, blockPtr.begin() + blockStride},
258-
int_val(1, 1),
259-
[&](const Value &index, const Value &shape, const Value &mask) {
260-
// mask = mask && (index < shape)
261-
return and_(icmp_slt(index, trunc(i32_ty, shape)), mask);
262-
});
254+
if (boundaryCheck.size() > 0) {
255+
// Get the LLVM values for mask
256+
maskElems.push_back(linearize(
257+
indicesInTensor,
258+
{blockPtr.begin() + blockShape, blockPtr.begin() + blockStride},
259+
int_val(1, 1),
260+
[&](const Value &index, const Value &shape, const Value &mask) {
261+
// mask = mask && (index < shape)
262+
return and_(icmp_slt(index, trunc(i32_ty, shape)), mask);
263+
}));
264+
}
263265
}
264266

265267
// Get the LLVM values for `other`

0 commit comments

Comments
 (0)