Skip to content

Commit ed49283

Browse files
authored
[FRONTEND] Throw an error when we would downcast an integral constant to a dtype it does not fit in (#5866)
1 parent 9e62654 commit ed49283

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

python/test/unit/language/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,18 @@ def kernel(X, SIZE: tl.constexpr):
318318
kernel[(1, )](x, SIZE=SIZE, num_warps=4)
319319

320320

321+
def test_scalar_overflow(device):
322+
323+
@triton.jit
324+
def kernel():
325+
huge_int: tl.constexpr = 0xFFFFFFFFFFFFFF
326+
x = tl.full((), 32, dtype=tl.int32)
327+
y = x + huge_int
328+
329+
with pytest.raises(triton.TritonError, match="out of range"):
330+
kernel[(1, )]()
331+
332+
321333
# generic test functions
322334
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1):
323335
check_type_supported(dtype_x, device) # early return if dtype_x is not supported

python/triton/language/semantic.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,11 @@ def binary_op_type_checking_impl(lhs: tl.tensor | numbers.Number, rhs: tl.tensor
186186
or rhs_is_scalar and rhs_scalar < 0 and ret_sca_ty.is_int_unsigned()):
187187
raise ValueError("Cannot perform a binary operation between an unsigned tensor and a negative scalar. "
188188
"Perform a explicit cast on one of them.")
189+
if ret_sca_ty.is_int():
190+
if lhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= lhs_scalar <= ret_sca_ty.get_int_max_value()):
191+
raise ValueError(f"Scalar {lhs_scalar} is out of range for type {ret_sca_ty}")
192+
if rhs_is_scalar and not (ret_sca_ty.get_int_min_value() <= rhs_scalar <= ret_sca_ty.get_int_max_value()):
193+
raise ValueError(f"Scalar {rhs_scalar} is out of range for type {ret_sca_ty}")
189194
lhs = full(
190195
(), lhs_scalar, dtype=ret_sca_ty, builder=builder) if lhs_is_scalar else cast(lhs, ret_sca_ty, builder)
191196
rhs = full(

0 commit comments

Comments
 (0)