
Commit 3f5f3e1

gnecula authored and Google-ML-Automation committed
[export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.
This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
1 parent 4a41aa0 commit 3f5f3e1

File tree

CHANGELOG.md
jax/_src/export/_export.py
tests/pallas/export_back_compat_pallas_test.py
tests/pallas/export_pallas_test.py

4 files changed: +14 -3 lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     return NaN for negative integer inputs, to match the behavior of SciPy from
     https://github.com/scipy/scipy/pull/21827.
   * `jax.clear_backends` was removed after being deprecated in v0.4.26.
+  * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
+    calls with guaranteed export stability, because this custom call relies on
+    Triton IR, which is not guaranteed to be stable. If you need to export code
+    that uses this custom call, you can use the `disabled_checks` parameter. See
+    more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).

 * New Features
   * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
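
For reference, after this change, exporting code that lowers to the Pallas GPU custom call requires explicitly disabling that safety check. The following is a minimal sketch modeled on the change to tests/pallas/export_pallas_test.py below; the kernel body, array shapes, and variable names are illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from jax import export
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Element-wise addition of the two input blocks.
  o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)

a = jnp.arange(8, dtype=jnp.float32)
exp = export.export(
    add_vectors,
    platforms=["tpu", "cuda"],
    # Opt in to exporting the Pallas GPU custom call, accepting that the
    # serialized artifact carries no compatibility guarantee for this call.
    disabled_checks=[
        export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
    ],
)(a, a)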

jax/_src/export/_export.py

Lines changed: 2 additions & 1 deletion
@@ -1005,7 +1005,8 @@ def _check_lowering(lowering) -> None:
     *_CPU_FFI_KERNELS,
     "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
     "cu_threefry2x32", "cu_threefry2x32_ffi",
-    "__gpu$xla.gpu.triton",  # Pallas call on GPU
+    # Triton IR does not guarantee stability.
+    # "__gpu$xla.gpu.triton",
     # cholesky on CPU
     "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
     # eigh on TPU
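
For context, _check_lowering is the gate that rejects exported modules containing custom calls outside this allow-list unless the corresponding safety check is disabled. A rough sketch of what a caller now observes, reusing the add_vectors function and array from the sketch above (the exact exception type and message are not reproduced here):

try:
  # Without disabled_checks, the Pallas GPU custom call is no longer accepted.
  export.export(add_vectors, platforms=["cuda"])(a, a)
except Exception as e:
  print("export rejected:", e)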

tests/pallas/export_back_compat_pallas_test.py

Lines changed: 3 additions & 2 deletions
@@ -48,9 +48,10 @@ def setUp(self):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()

-  @unittest.skip("TODO(necula): This test is checking backwards compatibility "
+  @unittest.skip("This test is checking backwards compatibility "
                  "of Triton IR, but Triton doesn't promise backwards "
-                 "compatibility for its IR.")
+                 "compatibility for its IR, and we have since removed "
+                 "the corresponding custom call from the guaranteed stable list.")
   def test_triton_add_one(self):
     def func(x):
       def add_one(x_ref, o_ref):

tests/pallas/export_pallas_test.py

Lines changed: 4 additions & 0 deletions
@@ -50,6 +50,10 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
     exp = export.export(
         add_vectors,
         platforms=["tpu", "cuda"],
+        # The Pallas GPU custom call is not enabled for export by default.
+        disabled_checks=[
+            export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
+        ]
     )(a, a)

     if (jtu.device_under_test() == "tpu" or
