|
13 | 13 |
|
14 | 14 | from .._C.libtriton import ir |
15 | 15 | from . import semantic |
| 16 | +from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape |
16 | 17 |
|
T = TypeVar('T')

# Attribute name stamped onto functions to mark them as Triton builtins
# (presumably read by the @builtin decorator / interpreter machinery —
# confirm against the decorator's definition, which is outside this view).
TRITON_BUILTIN = "__triton_builtin__"

# Re-export of the IR-level flag; the name suggests it controls whether NaN
# propagates through min/max-style ops — NOTE(review): confirm in libtriton.
PropagateNan = ir.PROPAGATE_NAN
@@ -612,18 +611,11 @@ def __init__(self, element_ty: dtype, shape: List): |
612 | 611 | # while tensor's shape is a list of constexpr. |
613 | 612 |
|
614 | 613 | # shape can be empty ([]) when an input is a 0D tensor. |
615 | | - if not shape: |
| 614 | + self.shape = _unwrap_shape(shape) |
| 615 | + if not self.shape: |
616 | 616 | raise TypeError('0d block_type is forbidden') |
617 | | - if isinstance(shape[0], constexpr): |
618 | | - shape = [s.value for s in shape] |
619 | | - |
620 | | - self.shape = shape |
621 | | - self.numel = 1 |
622 | | - for s in self.shape: |
623 | | - self.numel *= s |
624 | | - if self.numel > TRITON_MAX_TENSOR_NUMEL: |
625 | | - raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") |
626 | 617 |
|
| 618 | + self.numel = validate_block_shape(self.shape) |
627 | 619 | self.name = f'<{self.shape}, {self.element_ty}>' |
628 | 620 |
|
629 | 621 | def to_ir(self, builder: ir.builder) -> ir.block_type: |
@@ -1208,18 +1200,15 @@ def arange(start, end, _builder=None): |
1208 | 1200 | """ |
1209 | 1201 |
|
1210 | 1202 |
|
def _unwrap_shape(shape):
    """Return *shape* as a plain list with constexpr wrappers removed.

    Both the shape container itself and each of its elements may be
    ``constexpr``-wrapped; both levels are unwrapped here. No validation
    is performed — see ``_shape_check_impl`` for the checking variant.
    """
    plain = _constexpr_to_value(shape)
    return [_constexpr_to_value(dim) for dim in plain]
| 1206 | + |
| 1207 | + |
def _shape_check_impl(shape):
    """Unwrap *shape* to plain values and validate it as a block shape.

    Returns the unwrapped shape; raises whatever ``validate_block_shape``
    raises on an invalid shape.
    """
    unwrapped = _unwrap_shape(shape)
    validate_block_shape(unwrapped)
    return unwrapped
1223 | 1212 |
|
1224 | 1213 |
|
1225 | 1214 | @builtin |
|
0 commit comments