|
64 | 64 | ) |
65 | 65 |
|
66 | 66 |
|
67 | | -# This tracks the latest Mosaic IR version with a monthly delay. |
68 | | -FWD_COMPAT_IR_VERSION = 4 |
69 | | -DEFAULT_IR_VERSION = None |
70 | | -# TODO(jevinjiang): Remove this once both jaxlib and libtpu are up to date. |
71 | | -if is_cloud_tpu_older_than(2025, 4, 5) or jax.version._version_as_tuple( |
72 | | - jax.lib.__version__ |
73 | | -) < (0, 5, 4): |
74 | | - FWD_COMPAT_IR_VERSION = 3 |
75 | | - DEFAULT_IR_VERSION = 3 |
| 67 | +# Controls the IR serialization version. Upon incrementing the |
| 68 | +# default version in jaxlib/mosaic/dialect/tpu/transforms/serde.cc we must |
| 69 | +# continue to use the old serialization version when in forward compatibility |
| 70 | +# mode: for 1 month when exporting, or when using old cloud TPU. |
| 71 | +# |
| 72 | +# This can be achieved by adding: |
| 73 | +# if ctx.is_forward_compat() or is_cloud_tpu_older_than(<today>): |
| 74 | +# return <previous_serialization_version> |
| 75 | +# return None |
| 76 | +# |
| 77 | +# We should also add a TODO to remove the conditional one month later. |
| 78 | +def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None: |
| 79 | + # TODO(jevinjiang): remove the forward compatibility check after 2025-05-05. |
| 80 | + if ctx.is_forward_compat() or is_cloud_tpu_older_than(2025, 4, 5): |
| 81 | + return 3 |
| 82 | + return None |
76 | 83 |
|
77 | 84 |
|
78 | 85 | tpu_custom_call_p = core.Primitive("tpu_custom_call") |
@@ -679,9 +686,7 @@ def lower_module_to_custom_call( |
679 | 686 | serialization_format=serialization_format, |
680 | 687 | output_memory_spaces=output_memory_spaces, |
681 | 688 | kernel_name=kernel_name, |
682 | | - ir_version=FWD_COMPAT_IR_VERSION |
683 | | - if ctx.is_forward_compat() |
684 | | - else DEFAULT_IR_VERSION, |
| 689 | + ir_version=get_ir_version(ctx), |
685 | 690 | ) |
686 | 691 | return _tpu_custom_call_lowering( |
687 | 692 | ctx, |
|
0 commit comments