Skip to content

Commit 8d3f09e

Browse files
authored
[TensorDesc] Add extra stride validation to interpreter (#7713)
Ref: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 Note there is a formatting error in the docs, all tensor descriptors regardless of dtype need to satisfy: ```cpp globalStrides[0] = globalDim[0] * elementSizeInBytes(tensorDataType) + padding[0]; for (int i = 0; i < tensorRank - 1; i++) { globalStrides[i] = globalStrides[i – 1] * (globalDim[i] + padding[i]); assert(globalStrides[i] >= globalDim[i]); } ```
1 parent d82cfd3 commit 8d3f09e

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

python/triton/runtime/interpreter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,18 @@ def validate(self):
8888
assert self.base.data.item() % 16 == 0, "base must be 16-byte aligned"
8989
assert len(self.strides) == self.ndim
9090
assert len(self.block_shape) == self.ndim
91+
assert self.ndim >= 1, "descriptor cannot be 0 dimensional"
9192

9293
for stride in self.strides[:-1]:
9394
assert stride.data.item() % 16 == 0, "stride must be 16-byte aligned"
9495
assert self.strides[-1].data.item() == 1, "last dim must be contiguous"
96+
for i in range(self.ndim - 1):
97+
stride = self.strides[i].data.item()
98+
prev_stride = self.strides[i + 1].data.item()
99+
prev_size = self.shape[i + 1].data.item()
100+
assert stride >= prev_stride, "strides must be ordered largest to smallest"
101+
assert (stride % prev_stride) == 0, "strides must be even multiples of smaller strides"
102+
assert (stride // prev_stride) >= prev_size, "invalid stride"
95103

96104
def materialize_pointers(self, offsets: List[TensorHandle]):
97105
assert len(offsets) == self.ndim

0 commit comments

Comments
 (0)