
Commit 59b7b25

Always do boundary check on Tensor-Descriptor lowering (#4303)
Closes #4137, #4140, #4221.

This PR fixes failures in several tensor-descriptor tests. The failures occurred because the `tensor_descriptor.store` operation was writing too far out of bounds and overwrote a reference array allocated nearby. The upstream [`RewriteTensorDescriptorToPointer` pass](https://github.com/triton-lang/triton/blob/main/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp) always generates boundary checks for both [stores](https://github.com/triton-lang/triton/blob/09dc29800e918d5f5c8df4279d124f51e0a94987/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp#L286) and [loads](https://github.com/triton-lang/triton/blob/09dc29800e918d5f5c8df4279d124f51e0a94987/lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp#L263) using a boolean mask. The Intel-specific lowering, however, performs the following conversion without adding any boundary checks:

```
tensor_descriptor --(without bound-check)--> tensor_pointer --(without bound-check)--> llvm.load/store
```

The boolean mask based on the tensor shape is [supposed to be generated](https://github.com/intel/intel-xpu-backend-for-triton/blob/ae46511660d6a699132c67387394f511e825c90f/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp#L2558-L2564) in the `LoadStoreOpToLLVM` conversion pass, but that code [relies on the `boundaryCheck` attribute](https://github.com/intel/intel-xpu-backend-for-triton/blob/ae46511660d6a699132c67387394f511e825c90f/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp#L2564) of `tt.StoreOp`, which is not set during the `tensor_descriptor --> tensor_pointer` conversion. This PR fixes the problem by adding a `boundaryCheck` attribute to every load/store operation created in the `TensorDescToBlockPointer` pass.

---------

Signed-off-by: dchigarev <[email protected]>
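To make the failure mode concrete, here is a minimal, hypothetical sketch (not taken from the test suite) of the pattern the affected tests exercise; it assumes the upstream `tl.make_tensor_descriptor` / `desc.store` API, and the kernel name, shapes, and block sizes are illustrative. When `M` is not a multiple of `M_BLOCK`, the last block extends past the tensor shape, and without a boundary check the trailing rows are written into whatever memory follows the output buffer.

```python
# Hypothetical reproduction sketch (names and shapes are illustrative).
import triton
import triton.language as tl


@triton.jit
def store_blocks(out_ptr, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
    # Descriptor describing an M x N row-major tensor, accessed in
    # M_BLOCK x N_BLOCK tiles.
    desc = tl.make_tensor_descriptor(
        out_ptr,
        shape=[M, N],
        strides=[N, 1],
        block_shape=[M_BLOCK, N_BLOCK],
    )
    pid = tl.program_id(0)
    value = tl.full([M_BLOCK, N_BLOCK], 1.0, tl.float32)
    # If M is not a multiple of M_BLOCK, the last program's tile is partially
    # out of bounds; the descriptor's shape is what the lowering must use to
    # bound-check (mask) the trailing rows instead of writing past the buffer.
    desc.store([pid * M_BLOCK, 0], value)
```

The fix below attaches a `boundaryCheck` covering every dimension to the `tt.load`/`tt.store` operations that `TensorDescToBlockPointer` creates, so the existing mask generation in `LoadStoreOpToLLVM` can pick it up.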
1 parent 9bb257c · commit 59b7b25

File tree

4 files changed (+11, -13 lines)

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -209,8 +209,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 @pytest.mark.parametrize("dtype_str", tma_dtypes)
 @pytest.mark.parametrize("K_BLOCK", [16, 32, 64, 128])
 def test_tensor_descriptor_store3d(dtype_str, K_BLOCK, device):
-    if is_xpu() and dtype_str == 'bfloat16':
-        pytest.skip("FIXME: issue #4137")
 
     @triton.jit
     def kernel(out_ptr, a_ptr, M, N, K, stride_m, stride_n, stride_k, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr,
@@ -329,8 +327,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 def test_tensor_descriptor_store_nd(dtype_str, num_ctas, ndim, INNER_BLOCK, device):
     if num_ctas == 2 and (not is_cuda() or torch.cuda.get_device_capability(0)[0] not in (9, 10)):
         pytest.xfail("CTAs is unsupported for these cards")
-    if is_xpu() and ndim not in [1]:
-        pytest.skip("FIXME: issue #4140")
 
     @triton.jit
     def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE):
@@ -926,8 +922,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
 @pytest.mark.parametrize("ndim", [3, 4, 5])
 @pytest.mark.parametrize("INNER_BLOCK", [16, 32, 64, 128])
 def test_tensor_descriptor_rank_reducing_load(dtype_str, ndim, INNER_BLOCK, device):
-    if is_xpu():
-        pytest.skip("FIXME: issue #4221")
 
     @triton.jit
     def kernel(out_ptr, a_ptr, shape, strides, BLOCK_SHAPE):
```

test/Triton/Intel/TensorDescToBlockPointer/basic.mlir

Lines changed: 2 additions & 2 deletions
```diff
@@ -19,7 +19,7 @@ module {
 // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
 // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
 // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
-// CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<16x128xf32>>
+// CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
 // CHECK: tt.return
 // CHECK: }
 
@@ -43,7 +43,7 @@ module {
 // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
 // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
 // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<16x128xf32>>
-// CHECK: tt.store [[TENSOR_PTR]], [[CST]] : !tt.ptr<tensor<16x128xf32>>
+// CHECK: tt.store [[TENSOR_PTR]], [[CST]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x128xf32>>
 // CHECK: tt.return
 // CHECK: }
 }
```

test/Triton/Intel/TensorDescToBlockPointer/loop.mlir

Lines changed: 2 additions & 2 deletions
```diff
@@ -31,7 +31,7 @@ module {
 // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
 // CHECK-DAG: [[EXTSI_PARAM_2b:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
 // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2b]]], {{\[}}[[EXTSI_PARAM_2a]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] {{.*}} : <tensor<16x32xf16>>
-// CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<16x32xf16>>
+// CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x32xf16>>
 // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], [[LOAD]] : tensor<16x32xf16>
 // CHECK: scf.yield {{.*}}, [[ADD]] : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
 // CHECK: }
@@ -124,7 +124,7 @@ module {
 // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
 // CHECK-DAG: [[EXTSI_PARAM_2b:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
 // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2b]]], {{\[}}[[EXTSI_PARAM_2a]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] {{.*}} : <tensor<16x32xf16>>
-// CHECK: tt.store [[TENSOR_PTR]], [[VAR_arg2]] : !tt.ptr<tensor<16x32xf16>>
+// CHECK: tt.store [[TENSOR_PTR]], [[VAR_arg2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x32xf16>>
 // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], [[CST]] : tensor<16x32xf16>
 // CHECK: scf.yield {{.*}}, [[ADD]] : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
 // CHECK: }
```

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 7 additions & 3 deletions
```diff
@@ -202,17 +202,21 @@ struct TritonIntelTensorDescToBlockPointer
       llvm::dbgs().indent(2) << makeTensorPtrOp << "\n";
     });
 
+    SmallVector<int32_t> boundaryCheck;
+    for (size_t i = 0; i < makeTensorDescOp.getShape().size(); ++i)
+      boundaryCheck.push_back(i);
     constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
     if constexpr (isLoad) {
       auto loadOp = builder.createOrFold<tt::LoadOp>(
-          loc, makeTensorPtrOp, op.getCache(), op.getEvict(),
+          loc, makeTensorPtrOp, boundaryCheck, /*padding*/ std::nullopt,
+          op.getCache(), op.getEvict(),
           /*volatile*/ false);
       LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n");
       op.replaceAllUsesWith(loadOp);
     } else {
       [[maybe_unused]] auto storeOp = builder.createOrFold<tt::StoreOp>(
-          loc, makeTensorPtrOp, op.getSrc(), tt::CacheModifier::NONE,
-          tt::EvictionPolicy::NORMAL);
+          loc, makeTensorPtrOp, op.getSrc(), boundaryCheck,
+          tt::CacheModifier::NONE, tt::EvictionPolicy::NORMAL);
       LLVM_DEBUG(llvm::dbgs().indent(2) << storeOp << "\n");
     }
 
```