Skip to content

Commit d4ca359

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Give more time for the changes to propagate to libtpu
PiperOrigin-RevId: 873987035
1 parent a58b181 commit d4ca359

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,11 @@
6565
# We should also add a TODO to remove the conditional one month later.
6666
def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None:
6767
backend = ctx.module_context.get_backend(optional=True)
68-
# TODO(apaszke): remove the forward compatibility check after 2025-3-21.
68+
# TODO(apaszke): remove the forward compatibility check after 2025-4-1.
6969
if (
7070
ctx.is_forward_compat()
7171
or backend is None
72-
or is_cloud_tpu_older_than(2026, 2, 21, backend)
72+
or is_cloud_tpu_older_than(2026, 3, 1, backend)
7373
):
7474
return 9
7575
return None

tests/pallas/ops_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def test_elementwise(self, fn, dtype):
11271127
if (
11281128
fn in (jnp.sin, jnp.cos, jnp.tan)
11291129
and dtype == "bfloat16"
1130-
and not jtu.is_cloud_tpu_at_least(2026, 2, 21)
1130+
and not jtu.is_cloud_tpu_at_least(2026, 3, 1)
11311131
):
11321132
self.skipTest("requires a newer libTPU")
11331133
# TODO(b/370578663): implement these lowerings on TPU

tests/pallas/tpu_ops_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -791,9 +791,9 @@ def test_pack_elementwise(self, config, shape):
791791
if not jtu.is_device_tpu_at_least(version=5):
792792
self.skipTest("Requires TPU v5+")
793793
if packed_dtype == jnp.int2:
794-
if not jtu.is_cloud_tpu_at_least(2026, 2, 21):
794+
if not jtu.is_cloud_tpu_at_least(2026, 3, 1):
795795
raise self.skipTest(
796-
"int2 is only supported for tpu at least 02/21/2026"
796+
"int2 is only supported for tpu at least 03/01/2026"
797797
)
798798
if (shape[-2] % (8 * 16)) or (shape[-1] % 128):
799799
raise self.skipTest(

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,9 @@ def kernel(x_ref, o_ref):
10071007
np.testing.assert_array_equal(kernel(x), x + np.arange(self.num_lanes))
10081008

10091009
def test_write_to_transformed_ref(self):
1010+
if not jtu.is_cloud_tpu_at_least(2026, 3, 1):
1011+
self.skipTest("Requires a newer libTPU")
1012+
10101013
x = jnp.arange(2 * self.num_lanes)
10111014

10121015
@self.vector_subcore_kernel(out_shape=x)

0 commit comments

Comments
 (0)