Commit 7871be2

[Fix] fallback in specialize even for native types (#8122)
- should address a regression introduced in #7771
- adds a unit test which ideally covers all supported specializations, to spot other edge cases sooner (currently not exhaustive for dtypes and descriptor layouts)
1 parent 6fa1dd6 commit 7871be2
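
For context, a minimal Python sketch of why a trailing fallback helps; this is an illustration, not the C++ implementation. CPython's PyLong_Check and PyFloat_Check accept subclasses of int and float, so running them last catches values such as numpy.float64 that more specific dispatch earlier in specialize_arg may not match. The handle_long_type / handle_float_type helpers here are simplified stand-ins; the real handlers also take the backend and the specialization flags, as the diff below shows.

# Illustrative sketch only; not the Triton C++ code.
import numpy


def handle_long_type(arg):
    # Simplified width buckets, mirroring the reference implementation in the
    # new unit test: i32 when the value fits in 32 bits, i64 otherwise.
    return "i32" if -(2**31) <= arg <= 2**31 - 1 else "i64"


def handle_float_type(arg):
    return "fp32"


def specialize_sketch(arg):
    # ... bool / tensor / tuple / descriptor handling elided ...
    # Fallback for default types: isinstance() plays the role of
    # PyLong_Check / PyFloat_Check, which also accept subclasses.
    if isinstance(arg, int):
        return handle_long_type(arg)
    if isinstance(arg, float):
        return handle_float_type(arg)
    return None


assert specialize_sketch(numpy.float64(1.0)) == "fp32"  # float subclass
assert specialize_sketch(2**40) == "i64"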

2 files changed, +181 -0 lines

python/src/specialize.cc

Lines changed: 8 additions & 0 deletions
@@ -506,6 +506,14 @@ std::pair<py::object, py::object> specialize_arg(PyObject *backend,
     return handle_tensor(backend, arg, is_const, specialize_value, align);
   }
 
+  // fallback for default types
+  if (PyLong_Check(arg)) {
+    return handle_long_type(backend, arg, is_const, specialize_value, align);
+  }
+  if (PyFloat_Check(arg)) {
+    return handle_float_type(backend, arg, is_const, specialize_value, align);
+  }
+
   return {};
 }
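
For reference, the (type, key) pairs that the reference implementation in the new unit test (below) assigns to a few plain Python values; a hedged summary for specialize_value=False, where no backend key is computed.

# Derived from reference_specialize_impl in the new test below, with
# specialize_value=False (align unused on these paths), so every key is None.
expected_native_specializations = [
    (None, ("constexpr", None)),  # None always becomes a constexpr
    (True, ("u1", None)),         # bool maps to u1
    (1.0, ("fp32", None)),        # Python floats map to fp32
    (17, ("i32", None)),          # fits in 32 bits
    (2**40, ("i64", None)),       # beyond i32, still signed 64-bit
    (2**63, ("u64", None)),       # only representable as unsigned 64-bit
]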

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
import numpy
import pytest
import torch
from collections import namedtuple
from triton._C.libtriton import native_specialize_impl
from triton.runtime.jit import MockTensor, JITCallable
from triton._utils import canonicalize_dtype
from triton.backends.nvidia.compiler import CUDABackend
from triton.backends.amd.compiler import HIPBackend
from triton.language import constexpr
from triton.tools.tensor_descriptor import TensorDescriptor
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor as GluonTensorDescriptor
from triton.experimental.gluon.language._layouts import NVMMASharedLayout


def mock_tensor_from_tensor(tensor):
    return MockTensor(tensor.dtype, tensor.shape)


class MockJITCallable(JITCallable):

    def __init__(self):
        pass

    def cache_key(self):
        return "mock_jit_callable"


class MockFloat(float):

    def __new__(cls, value):
        return super().__new__(cls, value)


class MockInt(int):

    def __new__(cls, value):
        return super().__new__(cls, value)


def reference_specialize_impl(backend, arg, is_const, specialize_value, align):
    if arg is None:
        return ("constexpr", None)
    elif isinstance(arg, bool):
        return ("u1", None)
    elif isinstance(arg, int):
        key = backend.get_int_specialization(arg, align=align) if specialize_value else None
        if arg == 1 and specialize_value:
            return ("constexpr", 1)
        elif -(2**31) <= arg and arg <= 2**31 - 1:
            return ("i32", key)
        elif 2**63 <= arg and arg <= 2**64 - 1:
            return ("u64", key)
        else:
            return ("i64", key)
    elif isinstance(arg, float):
        return ("fp32", None)
    elif hasattr(arg, "data_ptr"):
        dsk = (arg.dtype, is_const)
        res = ("*k" if dsk[1] else "*") + canonicalize_dtype(dsk[0])
        key = backend.get_tensor_specialization(arg, align=align) if specialize_value else None
        return (res, key)
    elif isinstance(arg, JITCallable):
        return ("constexpr", arg.cache_key)
    elif isinstance(arg, constexpr):
        return ("constexpr", arg)
    elif isinstance(arg, tuple):
        spec = [reference_specialize_impl(backend, x, False, True, True) for x in arg]
        make_tuple = lambda vals: type(arg)(*vals) if hasattr(arg, "_fields") else tuple(vals)
        tys = make_tuple([x[0] for x in spec])
        keys = make_tuple([x[1] for x in spec])
        return (tys, keys)
    elif isinstance(arg, TensorDescriptor):
        assert hasattr(arg.base, "data_ptr")
        inner = canonicalize_dtype(arg.base.dtype)
        return (f"tensordesc<{inner}{list(arg.block_shape)}>", None)
    elif isinstance(arg, GluonTensorDescriptor):
        assert hasattr(arg.base, "data_ptr")
        inner = canonicalize_dtype(arg.base.dtype)
        return (f"tensordesc<{inner}{list(arg.block_shape)},{arg.layout!r}>", None)
    else:
        raise TypeError("Unsupported type: %s" % type(arg))


def native_inputs_to_specialize():
    return [
        1.0,
        None,
        False,
        True,
        1,
        0,
        -1,
        16,
        17,
        2**31 - 1,
        2**31,
        -(2**31) - 1,
        2**63 - 1,
        2**63,
        2**63 + 1,
        2**64 - 1,
    ]


def derived_inputs_to_specialize():
    return [
        constexpr(1),
        constexpr(False),
        constexpr(1.0),
        numpy.float64(1.0),
        MockFloat(1.0),
        MockInt(1),
        MockJITCallable(),
    ]


def tuples_to_specialize():
    return [
        (1, 1),
        (False, True),
        namedtuple('strides', ['x', 'y'])(1, 1),
        namedtuple('flags', ['x', 'y'])(False, True),
    ]


def tensors_to_specialize():
    return [
        torch.empty(shape, dtype=dtype, device="cpu")
        for shape in [(1, ), (1, 1), (16, ), (16, 16), (128, ), (128, 128)]
        for dtype in [torch.float64, torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int64]
    ]


def tensordescriptors_to_specialize():
    return [
        TensorDescriptor.from_tensor(tensor, block_shape=tensor.shape)
        for tensor in tensors_to_specialize()
        if tensor.shape[-1] % 16 == 0
    ]


def gluon_tensordescriptors_to_specialize():
    return [
        GluonTensorDescriptor.from_tensor(
            tensor,
            block_shape=tensor.shape,
            layout=NVMMASharedLayout(0, tensor.dtype.itemsize * 8, len(tensor.shape)),
        ) for tensor in tensors_to_specialize() if tensor.shape[-1] % 16 == 0
    ]


def mock_tensors_to_specialize():
    return [mock_tensor_from_tensor(tensor) for tensor in tensors_to_specialize()]


@pytest.mark.parametrize("input_generator", [
    native_inputs_to_specialize,
    tuples_to_specialize,
    tensors_to_specialize,
    tensordescriptors_to_specialize,
    gluon_tensordescriptors_to_specialize,
    mock_tensors_to_specialize,
])
@pytest.mark.parametrize("backend", [CUDABackend, HIPBackend])
@pytest.mark.parametrize("is_const", [True, False])
@pytest.mark.parametrize("specialize_value", [True, False])
@pytest.mark.parametrize("align", [True, False])
def test_specialize_impl(input_generator, backend, is_const, specialize_value, align):
    for arg in input_generator():
        result = native_specialize_impl(backend, arg, is_const, specialize_value, align)
        expected = reference_specialize_impl(backend, arg, is_const, specialize_value, align)
        assert result == expected
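
A minimal usage sketch mirroring the call in test_specialize_impl above; it assumes a local Triton build with the NVIDIA backend importable and, as in the test, passes the backend class itself rather than an instance.

# Sketch only; mirrors the call pattern used by test_specialize_impl.
from triton._C.libtriton import native_specialize_impl
from triton.backends.nvidia.compiler import CUDABackend

ty, key = native_specialize_impl(CUDABackend, 2**63, False, False, False)
# Per reference_specialize_impl, 2**63 only fits in an unsigned 64-bit slot,
# so ty == "u64"; with specialize_value=False no backend key is computed.
assert (ty, key) == ("u64", None)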
