We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e6dfe8f commit cb6881dCopy full SHA for cb6881d
tests/aot_test.py
@@ -19,6 +19,7 @@
19
import jax
20
from jax._src import core
21
from jax._src import test_util as jtu
22
+import jax._src.lib
23
from jax._src.lib import xla_client as xc
24
from jax.experimental import topologies
25
from jax.experimental.pjit import pjit
@@ -62,7 +63,11 @@ def verify_serialization(lowered):
62
63
jax.pmap(lambda x: x * x).lower(
64
np.zeros((len(jax.devices()), 4), dtype=np.float32)))
65
- @jtu.skip_on_devices('gpu') # Test fails in CI
66
+ @unittest.skipIf(
67
+ jax._src.lib.xla_extension_version < 300,
68
+ 'AOT compiler registration was broken in XLA extension version below'
69
+ ' 300.',
70
+ )
71
def test_topology_pjit_serialize(self):
72
try:
73
aot_topo = topologies.get_topology_desc(
0 commit comments