
Commit 4d2e9e5

[FRONTEND] Fix bitcast with constexpr dtype (#5382)
Fixes #5364
Parent: 5700c14
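Summary: tl.cast (and tensor.to) could receive their dtype argument still wrapped in tl.constexpr. The ordinary cast path unwrapped it inside semantic.cast, but the bitcast path handed the wrapper straight to semantic.bitcast, which failed. This commit unwraps dtype, fp_downcast_rounding, and bitcast once at the tl.cast entry point, makes tensor.to delegate to cast, and drops the now-redundant unwrapping from semantic.cast.

A minimal sketch of the pattern the fix enables (kernel and tensor names are illustrative, not from the commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def bitcast_kernel(X, Z, TO_TYPE: tl.constexpr, SIZE: tl.constexpr):
        offs = tl.arange(0, SIZE)
        x = tl.load(X + offs)
        # TO_TYPE arrives wrapped in tl.constexpr; before this commit the
        # bitcast path saw the wrapper instead of a tl.dtype and raised.
        tl.store(Z + offs, tl.cast(x, TO_TYPE, bitcast=True))

    x = torch.randn(1024, device='cuda')
    z = torch.empty(1024, dtype=torch.int32, device='cuda')
    bitcast_kernel[(1, )](x, z, TO_TYPE=tl.int32, SIZE=1024)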

File tree: 4 files changed, +22 -21 lines changed

python/test/unit/language/test_core.py

Lines changed: 13 additions & 7 deletions
@@ -21,6 +21,7 @@
 from triton._internal_testing import (
     integral_dtypes,
     int_dtypes,
+    str_to_triton_dtype,
     uint_dtypes,
     float_dtypes,
     float_dtypes_with_bfloat16,
@@ -1641,7 +1642,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
     ('float32', 'bfloat16', False, 1024),
     ('bfloat16', 'float32', False, 1024),
     ('float32', 'int32', True, 1024),
-    ('float32', 'int1', False, 1024),
+    ('float32', 'bool', False, 1024),
     ('int8', 'bfloat16', False, 1024),
 ] + [(f'uint{x}', f'int{x}', True, 1024)
      for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024)
@@ -1687,35 +1688,40 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
     # triton kernel
 
     @triton.jit
-    def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr):
+    def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constexpr):
         x_ptr = X + tl.arange(0, SIZE)
         z_ptr = Z + tl.arange(0, SIZE)
         x = tl.load(x_ptr)
 
         # Depending on the value of ARG_HASH (a "random" number determined by
         # the test parameters), spell the cast one of three different ways.
-        if ARG_HASH % 3 == 0:
+        if ARG_HASH % 4 == 0:
             z = x.to(Z.dtype.element_ty, bitcast=BITCAST)
-        elif ARG_HASH % 3 == 1:
+        elif ARG_HASH % 4 == 1:
             z = x.cast(Z.dtype.element_ty, bitcast=BITCAST)
-        else:
+        elif ARG_HASH % 4 == 2:
             z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST)
+        else:
+            z = tl.cast(x, TO_TYPE, bitcast=BITCAST)
 
         tl.store(z_ptr, z)
 
     # "Random" number used inside the kernel to determine how we spell the cast.
     # This way we don't have to increase the number of tests.
     arg_hash = hash((dtype_x, dtype_z, bitcast, size, num_ctas))
 
-    dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_'
+    dtype_z_np = dtype_z if dtype_z != 'bool' else 'bool_'
     # triton result
     if dtype_z.startswith('bfloat'):
         z_tri = torch.empty((size, ), dtype=getattr(torch, dtype_z), device=device)
     elif dtype_z.startswith('float8'):
         z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z))
     else:
         z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device)
-    kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1, num_ctas=num_ctas)
+
+    dtype_z_tri = str_to_triton_dtype(dtype_z)
+    kernel[(1, )](x_tri, z_tri, TO_TYPE=dtype_z_tri, BITCAST=bitcast, SIZE=size, ARG_HASH=arg_hash, num_warps=1,
+                  num_ctas=num_ctas)
     # torch result
     if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith(
             'float8') or dtype_x.startswith('float8'):
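The kernel now chooses among four spellings of the cast, with the new ARG_HASH % 4 == 3 branch covering the previously broken case: a destination dtype passed in as a constexpr argument rather than read off the output pointer. The 'int1' test parameter is also renamed to 'bool', with the numpy mapping updated to match. Side by side, the four spellings exercised are:

    z = x.to(Z.dtype.element_ty, bitcast=BITCAST)        # tensor method alias
    z = x.cast(Z.dtype.element_ty, bitcast=BITCAST)      # tensor method
    z = tl.cast(x, Z.dtype.element_ty, bitcast=BITCAST)  # free function
    z = tl.cast(x, TO_TYPE, bitcast=BITCAST)             # constexpr dtype (new)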

python/triton/_internal_testing.py

Lines changed: 5 additions & 1 deletion
@@ -9,7 +9,7 @@
 
 from numpy.random import RandomState
 from typing import Optional, Union
-from triton.runtime.jit import TensorWrapper, reinterpret
+from triton.runtime.jit import TensorWrapper, reinterpret, type_canonicalisation_dict
 
 int_dtypes = ['int8', 'int16', 'int32', 'int64']
 uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64']
@@ -119,6 +119,10 @@ def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
         return torch.tensor(x, device=device)
 
 
+def str_to_triton_dtype(x: str) -> tl.dtype:
+    return tl.str_to_ty(type_canonicalisation_dict[x])
+
+
 def torch_dtype_name(dtype) -> str:
     if isinstance(dtype, triton.language.dtype):
         return dtype.name
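The new helper canonicalises a human-readable dtype string before converting it to a tl.dtype object, which the test above passes to the kernel as a constexpr argument. Roughly how it behaves (the exact canonical spellings are assumptions based on type_canonicalisation_dict):

    from triton._internal_testing import str_to_triton_dtype
    import triton.language as tl

    assert str_to_triton_dtype('float32') == tl.float32
    # 'bool' canonicalises to the 1-bit integer type
    assert str_to_triton_dtype('bool') == tl.int1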

python/triton/language/core.py

Lines changed: 4 additions & 9 deletions
@@ -1024,13 +1024,7 @@ def to(self, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast:
         """
         Alias for :py:func:`tensor.cast`.
         """
-        # Triton doesn't like core functions calling other core functions, so we
-        # just copy-paste the implementation of cast here. It's not too bad.
-        dtype = _unwrap_if_constexpr(dtype)
-        bitcast = _unwrap_if_constexpr(bitcast)
-        if bitcast:
-            return semantic.bitcast(self, dtype, _builder)
-        return semantic.cast(self, dtype, _builder, fp_downcast_rounding)
+        return cast(self, dtype, fp_downcast_rounding, bitcast, _builder=_builder)
 
     # Type stubs for functions added by the _tensor_member_fn decorator.
     # (Unfortunately these can't be created automatically.)
@@ -1685,8 +1679,9 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcast
     :type bitcast: bool, optional
     """
     input = semantic.to_tensor(input, _builder)
-    if isinstance(bitcast, constexpr):
-        bitcast = bitcast.value
+    dtype = _constexpr_to_value(dtype)
+    fp_downcast_rounding = _constexpr_to_value(fp_downcast_rounding)
+    bitcast = _constexpr_to_value(bitcast)
     if bitcast:
         return semantic.bitcast(input, dtype, _builder)
     return semantic.cast(input, dtype, _builder, fp_downcast_rounding)
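Both the method and the free function now funnel through a single normalisation step at the top of cast. _constexpr_to_value is an existing helper in triton.language.core; its behaviour is roughly this sketch:

    def _constexpr_to_value(v):
        # Unwrap a tl.constexpr into its underlying Python value;
        # anything that is not a constexpr passes through unchanged.
        if isinstance(v, constexpr):
            return v.value
        return v

Because non-constexpr values pass through untouched, unwrapping dtype and fp_downcast_rounding here is harmless when callers already supply plain values.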

python/triton/language/semantic.py

Lines changed: 0 additions & 4 deletions
@@ -828,10 +828,6 @@ def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor:
 def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder,
          fp_downcast_rounding: Optional[str] = None) -> tl.tensor:
     src_ty = input.type
-    if isinstance(dst_ty, tl.constexpr):
-        dst_ty = dst_ty.value
-    if isinstance(fp_downcast_rounding, tl.constexpr):
-        fp_downcast_rounding = fp_downcast_rounding.value
     if src_ty.is_block():
         dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
     if src_ty == dst_ty: