You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fix JIT specialization data (de)serialization for tuples and constexprs (triton-lang#8639)
As part of using the `triton.knobs.runtime.jit_cache_hook`, the
`JITFunction` class performs JSON serialization on the specialization
data. The serialized specialization data is then expected to be used as
part of the `preload()` function, where it will be deserialized and used
to compile the Triton kernel.
However, this process fails to account for the following cases:
- When part of the Triton python signature is a Python tuple, the
serialization process will transform it into a list (Because JSON
serializes tuples as lists); the deserialization process does not
transform it back into a tuple, leading to a parsing failure when
`ast_to_ttir()` is invoked.
- When the constants contain a `tl.constexpr` value, the serialization
process will raise an error, because `tl.constexpr` is not serializable.
This PR addresses both of these issues by:
- Applying the reverse transformation in the deserialization from lists
to tuples for signatures. We can do this unconditionally because lists
are not accepted as part of the signature of a Triton kernel.
- Adding a special case for `constexpr` for constants in the
specialization data, so that it can be serialized and deserialized
without losing its type.
- Adding a test that is the exact same as
`test_passing_nested_tuple_with_constexpr`, but with the JIT hook setup
so that we can verify that the serialization/deserialization round-trip
works as intended.
0 commit comments