
Commit 53c2965

[Frontend] Factor out block shape validation function (#4915)
1 parent 55c9576 commit 53c2965

2 files changed: +33 −23 lines

python/triton/language/_utils.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from typing import List
+
+TRITON_MAX_TENSOR_NUMEL = 1048576
+
+
+def is_power_of_two(x):
+    return (x & (x - 1)) == 0
+
+
+def validate_block_shape(shape: List[int]):
+    numel = 1
+    for i, d in enumerate(shape):
+        if not isinstance(d, int):
+            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d)}]")
+        if not is_power_of_two(d):
+            raise ValueError(f"Shape element {i} must be a power of 2")
+        numel *= d
+
+    if numel > TRITON_MAX_TENSOR_NUMEL:
+        raise ValueError(f"numel ({numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
+    return numel
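
As a quick check of what the new helper enforces, here is a minimal sketch assuming a Triton checkout that includes this commit, so that triton.language._utils is importable:

from triton.language._utils import validate_block_shape

# Power-of-2 dims whose product stays under TRITON_MAX_TENSOR_NUMEL pass,
# and the product (64 * 128 = 8192) is returned.
print(validate_block_shape([64, 128]))  # 8192

# A non-power-of-2 dimension is rejected.
try:
    validate_block_shape([64, 100])
except ValueError as e:
    print(e)  # Shape element 1 must be a power of 2

# So is a shape whose numel exceeds the 1048576-element cap.
try:
    validate_block_shape([1024, 2048])
except ValueError as e:
    print(e)  # numel (2097152) exceeds triton maximum tensor numel (1048576)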

python/triton/language/core.py

Lines changed: 12 additions & 23 deletions
@@ -13,11 +13,10 @@
 
 from .._C.libtriton import ir
 from . import semantic
+from ._utils import TRITON_MAX_TENSOR_NUMEL, validate_block_shape
 
 T = TypeVar('T')
 
-TRITON_MAX_TENSOR_NUMEL = 1048576
-
 TRITON_BUILTIN = "__triton_builtin__"
 
 PropagateNan = ir.PROPAGATE_NAN
@@ -612,18 +611,11 @@ def __init__(self, element_ty: dtype, shape: List):
         # while tensor's shape is a list of constexpr.
 
         # shape can be empty ([]) when an input is a 0D tensor.
-        if not shape:
+        self.shape = _unwrap_shape(shape)
+        if not self.shape:
             raise TypeError('0d block_type is forbidden')
-        if isinstance(shape[0], constexpr):
-            shape = [s.value for s in shape]
-
-        self.shape = shape
-        self.numel = 1
-        for s in self.shape:
-            self.numel *= s
-        if self.numel > TRITON_MAX_TENSOR_NUMEL:
-            raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})")
 
+        self.numel = validate_block_shape(self.shape)
         self.name = f'<{self.shape}, {self.element_ty}>'
 
     def to_ir(self, builder: ir.builder) -> ir.block_type:
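
With validation factored out, block_type.__init__ reports bad shapes through the shared validator. A minimal sketch, assuming a Triton build with this commit; block_type is an internal class in triton.language.core and is constructed directly here only for illustration:

import triton.language as tl
from triton.language.core import block_type

# Valid block type: power-of-2 dims, numel computed by validate_block_shape.
blk = block_type(tl.float32, [16, 16])
print(blk.shape, blk.numel)  # [16, 16] 256

# A non-power-of-2 dimension now raises from the shared validator.
try:
    block_type(tl.float32, [3, 16])
except ValueError as e:
    print(e)  # Shape element 0 must be a power of 2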
@@ -1208,18 +1200,15 @@ def arange(start, end, _builder=None):
     """
 
 
-def _shape_check_impl(shape):
+def _unwrap_shape(shape):
     shape = _constexpr_to_value(shape)
-    for i, d in enumerate(shape):
-        if isinstance(d, int):
-            d = constexpr(d)
-        if not isinstance(d, constexpr):
-            raise TypeError(f"Shape element {i} must have type `constexpr`")
-        if not isinstance(d.value, int):
-            raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
-        if d.value & (d.value - 1) != 0:
-            raise ValueError(f"Shape element {i} must be a power of 2")
-    return [_constexpr_to_value(x) for x in shape]
+    return [_constexpr_to_value(s) for s in shape]
+
+
+def _shape_check_impl(shape):
+    shape = _unwrap_shape(shape)
+    validate_block_shape(shape)
+    return shape
 
 
 @builtin
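
The refactored call path in core.py is now: _unwrap_shape strips constexpr wrappers down to plain ints, and validate_block_shape performs the power-of-2 and numel checks. A minimal sketch of that composition, assuming a Triton build with this commit and using the public tl.constexpr type in place of the private helpers:

import triton.language as tl
from triton.language._utils import validate_block_shape

# Equivalent of _unwrap_shape: reduce constexpr-wrapped dims to plain ints.
shape = [tl.constexpr(32), tl.constexpr(64)]
unwrapped = [s.value if isinstance(s, tl.constexpr) else s for s in shape]

# Equivalent of _shape_check_impl: run the factored-out validation.
numel = validate_block_shape(unwrapped)
print(unwrapped, numel)  # [32, 64] 2048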
