Skip to content

Commit 1a9734c

Browse files
authored
Fix JIT specialization data (de)serialization for tuples and constexprs (triton-lang#8639)
As part of using the `triton.knobs.runtime.jit_cache_hook`, the `JITFunction` class performs JSON serialization on the specialization data. The serialized specialization data is then expected to be used as part of the `preload()` function, where it will be deserialized and used to compile the Triton kernel. However, this process fails to account for the following cases:

- When part of the Triton Python signature is a Python tuple, the serialization process will transform it into a list (because JSON serializes tuples as lists); the deserialization process does not transform it back into a tuple, leading to a parsing failure when `ast_to_ttir()` is invoked.
- When the constants contain a `tl.constexpr` value, the serialization process will raise an error, because `tl.constexpr` is not serializable.

This PR addresses both of these issues by:

- Applying the reverse transformation in the deserialization, from lists to tuples, for signatures. We can do this unconditionally because lists are not accepted as part of the signature of a Triton kernel.
- Adding a special case for `constexpr` for constants in the specialization data, so that it can be serialized and deserialized without losing its type.
- Adding a test that is the exact same as `test_passing_nested_tuple_with_constexpr`, but with the JIT hook set up, so that we can verify that the serialization/deserialization round-trip works as intended.
1 parent 1070cd5 commit 1a9734c

File tree

2 files changed

+61
-12
lines changed

2 files changed

+61
-12
lines changed

python/test/unit/language/test_tuple.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -256,17 +256,45 @@ def m_to_the_n(X, shape: tl.constexpr, strides, m_n):
256256
torch.testing.assert_close(x, expected_x, rtol=0, atol=0)
257257

258258

259+
@triton.jit
260+
def _nested_tuple_kernel(x):
261+
# This creates a new scope, which will force a copy of liveins. It's
262+
# important for this to happen as it forces IR flattening/unflattening,
263+
# which relies on the types being correct for the roundtrip to succeed.
264+
for _ in range(1):
265+
tl.static_assert(x[1][0] == 2)
266+
267+
259268
def test_passing_nested_tuple_with_constexpr(device):
269+
_nested_tuple_kernel[(1, )](((1, ), (tl.constexpr(2), )))
260270

261-
@triton.jit
262-
def test(x):
263-
# This creates a new scope, which will force a copy of liveins. It's
264-
# important for this to happen as it forces IR flattening/unflattening,
265-
# which relies on the types being correct for the roundtrip to succeed.
266-
for _ in range(1):
267-
tl.static_assert(x[1][0] == 2)
268-
269-
test[(1, )](((1, ), (tl.constexpr(2), )))
271+
272+
def test_passing_nested_tuple_with_constexpr_and_jit_hook(device, fresh_knobs):
273+
# get the serialized specialization data
274+
specialization_data = None
275+
276+
def cache_hook(*args, **kwargs):
277+
nonlocal specialization_data
278+
specialization_data = kwargs["compile"]["specialization_data"]
279+
280+
fresh_knobs.runtime.jit_cache_hook = cache_hook
281+
282+
device = getattr(torch, device).current_device()
283+
284+
# Clear the existing cache for this device to ensure that the hook is called;
285+
# This is needed because the kernel is shared between multiple tests and may
286+
# already have been compiled for this device.
287+
_nested_tuple_kernel.device_caches[device][0].clear()
288+
289+
warmup_run = _nested_tuple_kernel.warmup(((1, ), (tl.constexpr(2), )), grid=(1, ))
290+
assert warmup_run is not None
291+
292+
assert specialization_data is not None
293+
294+
preload_run = _nested_tuple_kernel.preload(specialization_data)
295+
assert preload_run is not None
296+
297+
assert warmup_run.hash == preload_run.hash
270298

271299

272300
def test_passing_tuple_to_make_tensor_descriptor(device, with_allocator):

python/triton/runtime/jit.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,12 @@ def __getitem__(self, grid) -> T:
366366

367367

368368
def serialize_specialization_data(name, signature, constants, attrs, options, key):
369-
constants = {key: str(value) if value.__class__.__name__ == "dtype" else value for key, value in constants.items()}
369+
constants = {
370+
key: str(value) if value.__class__.__name__ == "dtype" else
371+
{"constexpr": value.value} if value.__class__.__name__ == "constexpr" else value
372+
for key, value in constants.items()
373+
}
374+
370375
import json
371376
obj = {
372377
'name': name, 'signature': signature, 'constant_keys': [list(x) for x in constants.keys()], 'constant_vals':
@@ -557,6 +562,18 @@ def compute_cache_key(kernel_key_cache, specialization, options):
557562
return cache_key
558563

559564

565+
def convert_to_tuple_if_list(item):
566+
# If the incoming item is a list, recursively iterate through it to convert all lists therein into tuples
567+
if not isinstance(item, list):
568+
return item
569+
570+
# The value must be a list at this point
571+
for i, nested_value in enumerate(item):
572+
item[i] = convert_to_tuple_if_list(nested_value)
573+
574+
return tuple(item)
575+
576+
560577
class JITFunction(JITCallable, KernelInterface[T]):
561578

562579
def is_gluon(self):
@@ -759,13 +776,17 @@ def preload(self, specialization_data):
759776
constant_keys = map(tuple, deserialized_obj['constant_keys'])
760777
constant_vals = deserialized_obj['constant_vals']
761778
constexprs = {
762-
key: tl.dtype(value) if tl.dtype.is_dtype(value) else value
779+
key:
780+
tl.dtype(value) if tl.dtype.is_dtype(value) else
781+
tl.constexpr(value['constexpr']) if isinstance(value, dict) and 'constexpr' in value else value
763782
for key, value in zip(constant_keys, constant_vals)
764783
}
765784
attrs_keys = map(tuple, deserialized_obj['attrs_keys'])
766785
attrs_vals = deserialized_obj['attrs_vals']
767786
attrs = dict(zip(attrs_keys, attrs_vals))
768-
signature = dict(deserialized_obj['signature'].items())
787+
# JSON serializes tuples as lists, so they need to be converted back;
788+
# This can be done unconditionally, since lists are not accepted in Triton kernel signatures.
789+
signature = {key: convert_to_tuple_if_list(value) for key, value in deserialized_obj['signature'].items()}
769790
options = {
770791
key: tuple(value) if isinstance(value, list) else value
771792
for key, value in deserialized_obj['options'].items()

0 commit comments

Comments (0)