Skip to content

Commit e1e37f8

Browse files
bythew3iGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] FWD compatibility needs to keep previous version at least one month.
PiperOrigin-RevId: 744796256
1 parent fcf5115 commit e1e37f8

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,22 @@
6464
)
6565

6666

67-
# This tracks the latest Mosaic IR version with a monthly delay.
68-
FWD_COMPAT_IR_VERSION = 4
69-
DEFAULT_IR_VERSION = None
70-
# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date.
71-
if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple(
72-
jax.lib.__version__
73-
) < (0, 5, 4):
74-
FWD_COMPAT_IR_VERSION = 3
75-
DEFAULT_IR_VERSION = 3
67+
# Controls the IR serialization version. Upon incrementing the
68+
# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must
69+
# continue to use the old serialization version when in forward compatibility
70+
# mode: for 1 month when exporting, or when using old cloud TPU.
71+
#
72+
# This can be achieved by adding:
73+
# if ctx.is_forward_compat() or is_cloud_tpu_older_than(<today>):
74+
# return <previous_serialization_version>
75+
# return None
76+
#
77+
# We should also add a TODO to remove the conditional one month later.
78+
def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None:
79+
# TODO(jevinjiang): remove the forward compatibility check after 2025-05-05.
80+
if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 4, 5):
81+
return 3
82+
return None
7683

7784

7885
tpu_custom_call_p = core.Primitive("tpu_custom_call")
@@ -679,9 +686,7 @@ def lower_module_to_custom_call(
679686
serialization_format=serialization_format,
680687
output_memory_spaces=output_memory_spaces,
681688
kernel_name=kernel_name,
682-
ir_version=FWD_COMPAT_IR_VERSION
683-
if ctx.is_forward_compat()
684-
else DEFAULT_IR_VERSION,
689+
ir_version=get_ir_version(ctx),
685690
)
686691
return _tpu_custom_call_lowering(
687692
ctx,

0 commit comments

Comments
 (0)