Skip to content

Commit cb6881d

Browse files
beckerheGoogle-ML-Automation
authored andcommitted
Reverts bdadc53
PiperOrigin-RevId: 704758075
1 parent e6dfe8f commit cb6881d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/aot_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import jax
2020
from jax._src import core
2121
from jax._src import test_util as jtu
22+
import jax._src.lib
2223
from jax._src.lib import xla_client as xc
2324
from jax.experimental import topologies
2425
from jax.experimental.pjit import pjit
@@ -62,7 +63,11 @@ def verify_serialization(lowered):
6263
jax.pmap(lambda x: x * x).lower(
6364
np.zeros((len(jax.devices()), 4), dtype=np.float32)))
6465

65-
@jtu.skip_on_devices('gpu') # Test fails in CI
66+
@unittest.skipIf(
67+
jax._src.lib.xla_extension_version < 300,
68+
'AOT compiler registration was broken in XLA extension version below'
69+
' 300.',
70+
)
6671
def test_topology_pjit_serialize(self):
6772
try:
6873
aot_topo = topologies.get_topology_desc(

0 commit comments

Comments
 (0)