File tree Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Expand file tree Collapse file tree 2 files changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -409,6 +409,7 @@ def __init__(self, name):
409
409
self .name = name
410
410
assert name in dtype .SINT_TYPES + dtype .UINT_TYPES + dtype .FP_TYPES + dtype .OTHER_TYPES , name
411
411
self .primitive_bitwidth = get_primitive_bitwidth (name )
412
+ self .itemsize = self .primitive_bitwidth // 8
412
413
if name in dtype .SINT_TYPES :
413
414
self .int_signedness = dtype .SIGNEDNESS .SIGNED
414
415
self .int_bitwidth = self .primitive_bitwidth
Original file line number Diff line number Diff line change 1
1
from dataclasses import dataclass
2
2
from typing import List , Any
3
- from triton ._utils import validate_block_shape , canonicalize_dtype , get_primitive_bitwidth
3
+ from triton ._utils import validate_block_shape
4
4
5
5
6
6
@dataclass
@@ -18,8 +18,7 @@ def __post_init__(self):
18
18
assert rank <= 5 , "rank cannot be more than 5"
19
19
assert self .base .data_ptr () % 16 == 0 , "base must be 16-byte aligned"
20
20
validate_block_shape (self .block_shape )
21
- dtype_str = canonicalize_dtype (self .base .dtype )
22
- elem_bytes = get_primitive_bitwidth (dtype_str ) // 8
21
+ elem_bytes = self .base .dtype .itemsize
23
22
for stride in self .strides [:- 1 ]:
24
23
assert (stride * elem_bytes ) % 16 == 0 , "strides must be 16-byte aligned"
25
24
assert self .strides [- 1 ] == 1 , "Last dimension must be contiguous"
You can’t perform that action at this time.
0 commit comments