Skip to content

Commit 762ace9

Browse files
authored
[PYTHON][TOOLS] simplify construction of tensor descriptor (#7159)
base dtype may not be (and doesn't need to be) canonicalizable
1 parent 1572ee6 commit 762ace9

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

python/triton/language/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ def __init__(self, name):
409409
self.name = name
410410
assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name
411411
self.primitive_bitwidth = get_primitive_bitwidth(name)
412+
self.itemsize = self.primitive_bitwidth // 8
412413
if name in dtype.SINT_TYPES:
413414
self.int_signedness = dtype.SIGNEDNESS.SIGNED
414415
self.int_bitwidth = self.primitive_bitwidth

python/triton/tools/tensor_descriptor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
from typing import List, Any
3-
from triton._utils import validate_block_shape, canonicalize_dtype, get_primitive_bitwidth
3+
from triton._utils import validate_block_shape
44

55

66
@dataclass
@@ -18,8 +18,7 @@ def __post_init__(self):
1818
assert rank <= 5, "rank cannot be more than 5"
1919
assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
2020
validate_block_shape(self.block_shape)
21-
dtype_str = canonicalize_dtype(self.base.dtype)
22-
elem_bytes = get_primitive_bitwidth(dtype_str) // 8
21+
elem_bytes = self.base.dtype.itemsize
2322
for stride in self.strides[:-1]:
2423
assert (stride * elem_bytes) % 16 == 0, "strides must be 16-byte aligned"
2524
assert self.strides[-1] == 1, "Last dimension must be contiguous"

0 commit comments

Comments
 (0)