
Commit 362fb7a

hawkinsp authored and Google-ML-Automation committed
Remove code to support jaxlib < 0.5.3.
The new xla_extension_version is 320. PiperOrigin-RevId: 738522486
1 parent 4489303 commit 362fb7a

10 files changed: +16 −110 lines

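All ten files delete the same kind of guard: a check against `xla_extension_version` (or a `hasattr`/`getattr` probe) that kept JAX working against jaxlib wheels older than 0.5.3. A minimal sketch of the before/after shape, using the internal names that appear in the hunks below (illustrative only, not a standalone snippet):

    # Sketch of the guard style this commit deletes. The names come from the
    # diffs below; whether they exist depends on the installed jaxlib.
    from jax._src.lib import xla_extension, xla_extension_version

    def lowered_with_shardy(module_bytes: bytes) -> bool:
        # Before: feature-gate on the XLA extension version so the code still
        # imports and runs against jaxlib wheels that predate the feature.
        if xla_extension_version >= 319:
            return xla_extension.sdy.lowered_with_shardy(module_bytes)
        return False

    def lowered_with_shardy_unconditional(module_bytes: bytes) -> bool:
        # After: with jaxlib >= 0.5.3 (xla_extension_version 320) as the
        # minimum, the call can be made unconditionally.
        return xla_extension.sdy.lowered_with_shardy(module_bytes)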

jax/_src/export/_export.py

Lines changed: 6 additions & 10 deletions
@@ -43,7 +43,7 @@
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.lib import xla_client
-from jax._src.lib import xla_extension, xla_extension_version
+from jax._src.lib import xla_extension
 from jax._src.lib.mlir import ir, passmanager
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
@@ -674,10 +674,8 @@ def _export_lowered(
   # Shardy was used during lowering if we can find the Shardy mesh in the
   # module. Note that the mesh should have been lifted by the
   # `sdy-lift-inlined-meshes` pass in mlir.py.
-  shardy_enabled = False
-  if xla_extension_version >= 319:
-    shardy_enabled = xla_extension.sdy.lowered_with_shardy(
-        mlir.module_to_bytecode(mlir_module))
+  shardy_enabled = xla_extension.sdy.lowered_with_shardy(
+      mlir.module_to_bytecode(mlir_module))

   mlir_module_serialized = _module_to_bytecode(mlir_module, shardy_enabled)

@@ -784,7 +782,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
                       _get_vjp=_get_exported_vjp)

 def _module_to_bytecode(module: ir.Module, shardy_enabled: bool) -> bytes:
-  if xla_extension_version >= 319 and shardy_enabled:
+  if shardy_enabled:
     mlir_str = xla_extension.sdy.sdy_round_trip_export_pipeline(
         mlir.module_to_bytecode(module))
   else:
@@ -1423,10 +1421,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   ctx.module_context.shape_poly_state.uses_dim_vars = True
   submodule = ir.Module.parse(exported.mlir_module())

-  shardy_enabled = False
-  if xla_extension_version >= 319:
-    shardy_enabled = xla_extension.sdy.lowered_with_shardy(
-        mlir.module_to_bytecode(submodule))
+  shardy_enabled = xla_extension.sdy.lowered_with_shardy(
+      mlir.module_to_bytecode(submodule))
   if shardy_enabled:
     submodule = ir.Module.parse(xla_extension.sdy.sdy_round_trip_import_shardings(
         mlir.module_to_bytecode(submodule)))
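For orientation, `_export_lowered` and `_module_to_bytecode` sit behind the public `jax.export` API. A rough usage sketch (the function `f`, shapes, and round trip shown here are illustrative; exact serialization behavior depends on the JAX version):

    import jax
    import jax.numpy as jnp
    from jax import export

    def f(x):
      return jnp.sin(x) * 2.0

    # export() consumes a jitted function; the resulting artifact embeds the
    # StableHLO module produced by the lowering code patched above.
    exported = export.export(jax.jit(f))(jnp.ones((4,), jnp.float32))
    blob = exported.serialize()
    restored = export.deserialize(blob)
    print(restored.call(jnp.arange(4.0)))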

jax/_src/interpreters/mlir.py

Lines changed: 3 additions & 6 deletions
@@ -55,7 +55,7 @@
     SdyArraySharding, SdyArrayShardingList)
 from jax._src.util import foreach
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension, xla_extension_version
+from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
 from jax._src.lib.mlir.dialects import func as func_dialect, hlo
 from jax._src.lib.mlir import register_jax_dialects
@@ -3031,11 +3031,8 @@ def refine_polymorphic_shapes(module: ir.Module) -> ir.Module:
         mlir_module=module_to_bytecode(module),
         enable_shape_assertions=True,
         validate_static_shapes=True)
-    if xla_extension_version >= 319:
-      refined_module_str = refine_polymorphic_shapes(
-          enable_shardy=config.use_shardy_partitioner.value)
-    else:
-      refined_module_str = refine_polymorphic_shapes()
+    refined_module_str = refine_polymorphic_shapes(
+        enable_shardy=config.use_shardy_partitioner.value)
   except Exception as e:
     raise ValueError(
         "Error refining shapes. " +

jax/_src/lax/lax.py

Lines changed: 0 additions & 6 deletions
@@ -66,7 +66,6 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.lib import xla_extension_version
 from jax._src.sharding_impls import (PmapSharding, NamedSharding,
                                      PartitionSpec as P, canonicalize_sharding)
 from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
@@ -2267,11 +2266,6 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
       case DotAlgorithmPreset.BF16_BF16_F32_X6:
         return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False)
       case DotAlgorithmPreset.BF16_BF16_F32_X9:
-        if xla_extension_version < 320:
-          raise ValueError(
-              "The dot algorithm BF16_BF16_F32_X9 requires XLA extension "
-              "version >= 320."
-          )
         return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 9, False)
       case DotAlgorithmPreset.TF32_TF32_F32:
         return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False)
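With the version check gone, `BF16_BF16_F32_X9` is handled like the other presets and is rejected only when the backend itself lacks it. A hedged usage sketch, passing the preset through the `precision` argument (requires a backend that implements the algorithm):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((128, 128), jnp.bfloat16)
    y = jnp.ones((128, 128), jnp.bfloat16)

    # Request the 9-pass bf16 emulation of an f32 matmul; unsupported backends
    # are expected to reject the algorithm at lowering/compile time.
    out = jax.lax.dot(x, y,
                      precision=jax.lax.DotAlgorithmPreset.BF16_BF16_F32_X9,
                      preferred_element_type=jnp.float32)
    print(out.dtype, out.shape)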

jax/_src/sharding_impls.py

Lines changed: 1 addition & 3 deletions
@@ -33,7 +33,6 @@
 from jax._src import xla_bridge as xb
 from jax._src import mesh_utils
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir.dialects import sdy
 from jax._src.named_sharding import ( # noqa: F401
     SdyArraySharding, SdyDimSharding, UnspecifiedValue, AUTO,
@@ -881,8 +880,7 @@ def parse_flatten_op_sharding(
     return out
   elif hlo_sharding.is_replicated():
     return [PartitionSpec()]
-  elif (xla_extension_version >= 319 and hlo_sharding.is_maximal()
-        and mesh.size == 1):
+  elif hlo_sharding.is_maximal() and mesh.size == 1:
     return [PartitionSpec()]
   elif hlo_sharding.is_tiled():
     mesh_shape = mesh.shape
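The new branch maps a maximal `HloSharding` on a single-device mesh to an empty `PartitionSpec`, i.e. fully replicated. For reference, the user-facing equivalent looks roughly like this (sketch; needs at least one device):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # On a one-device mesh an empty PartitionSpec means "replicated", which is
    # what parse_flatten_op_sharding now also returns for a maximal sharding.
    mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))
    replicated = NamedSharding(mesh, PartitionSpec())
    print(replicated)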

jax/_src/util.py

Lines changed: 2 additions & 60 deletions
@@ -108,11 +108,7 @@ def foreach(f, *args):
     return None

 else:
-  # TODO(phawkins): remove after jaxlib 0.5.2 is the minimum.
-  if hasattr(jaxlib_utils, 'foreach'):
-    foreach = jaxlib_utils.foreach
-  else:
-    foreach = safe_map
+  foreach = jaxlib_utils.foreach


 def unzip2(xys: Iterable[tuple[T1, T2]]
@@ -244,61 +240,8 @@ def curry(f):
   """
   return wraps(f)(partial(partial, f))

-# TODO(phawkins): make this unconditional after jaxlib 0.5.3 is the minimum.
 toposort: Callable[[Iterable[Any]], list[Any]]
-if hasattr(jaxlib_utils, "topological_sort"):
-  toposort = partial(jaxlib_utils.topological_sort, "parents")
-else:
-
-  def toposort(end_nodes):
-    if not end_nodes:
-      return []
-    end_nodes = _remove_duplicates(end_nodes)
-
-    child_counts = {}
-    stack = list(end_nodes)
-    while stack:
-      node = stack.pop()
-      if id(node) in child_counts:
-        child_counts[id(node)] += 1
-      else:
-        child_counts[id(node)] = 1
-        stack.extend(node.parents)
-    for node in end_nodes:
-      child_counts[id(node)] -= 1
-
-    sorted_nodes = []
-    childless_nodes = [
-        node for node in end_nodes if child_counts[id(node)] == 0
-    ]
-    assert childless_nodes
-    while childless_nodes:
-      node = childless_nodes.pop()
-      sorted_nodes.append(node)
-      for parent in node.parents:
-        if child_counts[id(parent)] == 1:
-          childless_nodes.append(parent)
-        else:
-          child_counts[id(parent)] -= 1
-    sorted_nodes = sorted_nodes[::-1]
-
-    check_toposort(sorted_nodes)
-    return sorted_nodes
-
-  def check_toposort(nodes):
-    visited = set()
-    for node in nodes:
-      assert all(id(parent) in visited for parent in node.parents)
-      visited.add(id(node))
-
-  def _remove_duplicates(node_list):
-    seen = set()
-    out = []
-    for n in node_list:
-      if id(n) not in seen:
-        seen.add(id(n))
-        out.append(n)
-    return out
+toposort = partial(jaxlib_utils.topological_sort, "parents")


 def split_merge(predicate, xs):
@@ -320,7 +263,6 @@ def merge(new_lhs, new_rhs):

   return lhs, rhs, merge

-
 def _ignore(): return None

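Both the deleted pure-Python fallback and the remaining `jaxlib_utils.topological_sort` binding order nodes by following a `parents` attribute, parents before children. A small sketch of that contract (the `Node` class is illustrative, not part of JAX):

    from dataclasses import dataclass, field

    from jax._src.util import toposort

    @dataclass
    class Node:
      name: str
      parents: list = field(default_factory=list)

      def __repr__(self):
        return self.name

    # d is a parent of b and c, which are both parents of a.
    d = Node("d")
    b, c = Node("b", [d]), Node("c", [d])
    a = Node("a", [b, c])

    # Parents come before children in the result, e.g. [d, b, c, a]
    # (the relative order of b and c is an implementation detail).
    print(toposort([a]))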

jax/experimental/mosaic/gpu/dialect_lowering.py

Lines changed: 2 additions & 6 deletions
@@ -849,13 +849,9 @@ def _mgpu_wait_op_lowering_rule(
   return []


-# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
-SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
-
-
-@_register_lowering(SliceSMEMOp)
+@_register_lowering(mgpu.SliceSMEMOp)
 def _mgpu_slice_smem_op_lowering_rule(
-    ctx: LoweringContext, op: SliceSMEMOp
+    ctx: LoweringContext, op: mgpu.SliceSMEMOp
 ) -> Sequence[ir.Value]:
   del ctx
   return [_slice_smem(op.result.type, op.offset)]
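This file and transform_inference.py below drop the same feature-detection idiom for ops that older jaxlib builds may not define. A generic sketch of the pattern (the registry and decorator here are illustrative stand-ins, not the actual Mosaic GPU helpers):

    from typing import Callable, Optional

    _lowerings: dict[type, Callable] = {}

    def register_lowering(op_cls: Optional[type]):
      # Tolerate an op class that is absent in an older jaxlib: skip
      # registration instead of failing at import time.
      def decorator(rule: Callable) -> Callable:
        if op_cls is not None:
          _lowerings[op_cls] = rule
        return rule
      return decorator

    # Old, guarded form (what the diff deletes):
    #   SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)
    #   @register_lowering(SliceSMEMOp)
    # New form, valid once every supported jaxlib defines the op:
    #   @register_lowering(mgpu.SliceSMEMOp)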

jax/experimental/mosaic/gpu/transform_inference.py

Lines changed: 2 additions & 4 deletions
@@ -172,11 +172,9 @@ def _infer_vector_load_store_transforms(
   return None


-# TODO(bchetioui): remove this once jaxlib minimum version >= 0.5.2.
-SliceSMEMOp = getattr(mgpu, "SliceSMEMOp", None)

-@partial(_add_transform_inference_rule, SliceSMEMOp)
-def _infer_slice_smem_transforms(op: SliceSMEMOp) -> OptionalTransforms:
+@partial(_add_transform_inference_rule, mgpu.SliceSMEMOp)
+def _infer_slice_smem_transforms(op: mgpu.SliceSMEMOp) -> OptionalTransforms:
   transforms = None
   uses = cast(ir.OpResult, op.result).uses

jax/experimental/sparse/linalg.py

Lines changed: 0 additions & 6 deletions
@@ -29,7 +29,6 @@
 from jax._src import core
 from jax._src import ffi
 from jax._src.interpreters import ad
-from jax._src.lib import gpu_solver

 import numpy as np
 from scipy.sparse import csr_matrix, linalg
@@ -534,11 +533,6 @@ def _spsolve_abstract_eval(data, indices, indptr, b, *, tol, reorder):


 def _spsolve_gpu_lowering(ctx, data, indices, indptr, b, *, tol, reorder):
-  # TODO(danfm): remove after JAX 0.5.1 release.
-  if hasattr(gpu_solver, "cuda_csrlsvqr"):
-    data_aval, _, _, _, = ctx.avals_in
-    return gpu_solver.cuda_csrlsvqr(data_aval.dtype, data, indices,
-                                    indptr, b, tol, reorder)
   return ffi.ffi_lowering("cusolver_csrlsvqr_ffi")(
       ctx, data, indices, indptr, b, tol=np.float64(tol),
       reorder=np.int32(reorder))
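The lowering above backs `jax.experimental.sparse.linalg.spsolve`, which solves a CSR system A x = b via cuSOLVER's csrlsvqr. A rough usage sketch (requires a CUDA backend; the matrix values are illustrative):

    import jax.numpy as jnp
    from jax.experimental import sparse
    from jax.experimental.sparse.linalg import spsolve

    # Build a small CSR system A x = b; spsolve takes the raw CSR buffers and
    # forwards tol/reorder to the csrlsvqr FFI call lowered above.
    A = sparse.BCSR.fromdense(jnp.array([[4.0, 1.0, 0.0],
                                         [1.0, 3.0, 0.0],
                                         [0.0, 0.0, 2.0]]))
    b = jnp.array([1.0, 2.0, 3.0])
    x = spsolve(A.data, A.indices, A.indptr, b, tol=1e-6, reorder=1)
    print(x)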

tests/lax_test.py

Lines changed: 0 additions & 6 deletions
@@ -49,7 +49,6 @@
 from jax._src.lax import lax as lax_internal
 from jax._src.util import NumpyComplexWarning, safe_zip
 from jax._src.tree_util import tree_map
-from jax._src.lib import xla_extension_version

 config.parse_flags_with_absl()

@@ -1128,11 +1127,6 @@ def testDotAlgorithm(self, algorithm, dtype):
       raise SkipTest(
           f"The dot algorithm '{algorithm}' is not supported on CPU.")
     if jtu.test_device_matches(["gpu"]):
-      if (algorithm == lax.DotAlgorithmPreset.BF16_BF16_F32_X9 and
-          xla_extension_version < 320):
-        raise SkipTest(
-            f"The dot algorithm ${algorithm} requires XLA extension version "
-            ">= 320.")
       # GPU algorithm support is a little spotty. It is checked in
       # xla/service/algorithm_util.cc and the logic is copied here.
       if algorithm in {

tests/linalg_test.py

Lines changed: 0 additions & 3 deletions
@@ -867,9 +867,6 @@ def testSVD(self, b, m, n, dtype, full_matrices, compute_uv, hermitian, algorithm):
       self.skipTest("Hermitian SVD doesn't support the algorithm parameter.")
     if not jtu.test_device_matches(["cpu", "gpu"]):
       self.skipTest("SVD algorithm selection only supported on CPU and GPU.")
-    # TODO(danfm): Remove this check after 0.5.2 is released.
-    if jtu.test_device_matches(["cpu"]) and jtu.jaxlib_version() <= (0, 5, 1):
-      self.skipTest("SVD algorithm selection on CPU requires a newer jaxlib version.")
     if jtu.test_device_matches(["cpu"]) and algorithm == lax.linalg.SvdAlgorithm.JACOBI:
       self.skipTest("Jacobi SVD not supported on GPU.")
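The removed skip gated explicit SVD algorithm selection on CPU behind a newer jaxlib. For reference, selecting an algorithm explicitly looks roughly like this (sketch; which algorithms are available still depends on the backend):

    import jax.numpy as jnp
    from jax import lax

    x = jnp.arange(32.0).reshape(8, 4)

    # Explicitly request the QR-based SVD; JACOBI and DEFAULT are the other
    # values of the SvdAlgorithm enum.
    u, s, vt = lax.linalg.svd(x, full_matrices=False,
                              algorithm=lax.linalg.SvdAlgorithm.QR)
    print(u.shape, s.shape, vt.shape)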
