Skip to content

Commit 5fe5206

Browse files
gneculaGoogle-ML-Automation
authored andcommitted
[shape_poly] Remove some deprecated kwargs
PiperOrigin-RevId: 703116755
1 parent e510295 commit 5fe5206

File tree

2 files changed

+3
-23
lines changed

2 files changed

+3
-23
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
5353
use `uses_global_constants`.
5454
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
5555
`platforms` instead.
56+
* The kwargs `symbolic_scope` and `symbolic_constraints` from
57+
{func}`jax.export.symbolic_args_specs` have been removed. They were
58+
deprecated in June 2024. Use `scope` and `constraints` instead.
5659
* Hashing of tracers, which has been deprecated since version 0.4.30, now
5760
results in a `TypeError`.
5861
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and

jax/_src/export/shape_poly.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,12 +1198,6 @@ def is_symbolic_dim(p: DimSize) -> bool:
11981198
"""
11991199
return isinstance(p, _DimExpr)
12001200

1201-
def is_poly_dim(p: DimSize) -> bool:
1202-
# TODO: deprecated January 2024, remove June 2024.
1203-
warnings.warn("is_poly_dim is deprecated, use export.is_symbolic_dim",
1204-
DeprecationWarning, stacklevel=2)
1205-
return is_symbolic_dim(p)
1206-
12071201
dtypes.python_scalar_dtypes[_DimExpr] = dtypes.python_scalar_dtypes[int]
12081202

12091203
def _einsum_contract_path(*operands, **kwargs):
@@ -1413,8 +1407,6 @@ def symbolic_args_specs(
14131407
shapes_specs, # prefix pytree of strings
14141408
constraints: Sequence[str] = (),
14151409
scope: SymbolicScope | None = None,
1416-
symbolic_constraints: Sequence[str] = (), # DEPRECATED on 6/14/24
1417-
symbolic_scope: SymbolicScope | None = None, # DEPRECATED on 6/14/24
14181410
):
14191411
"""Constructs a pytree of jax.ShapeDtypeSpec arguments specs for `export`.
14201412
@@ -1435,25 +1427,10 @@ def symbolic_args_specs(
14351427
arguments](https://jax.readthedocs.io/en/latest/pytrees.html#applying-optional-parameters-to-pytrees).
14361428
constraints: as for :func:`jax.export.symbolic_shape`.
14371429
scope: as for :func:`jax.export.symbolic_shape`.
1438-
symbolic_constraints: DEPRECATED, use `constraints`.
1439-
symbolic_scope: DEPRECATED, use `scope`.
14401430
14411431
Returns: a pytree of jax.ShapeDTypeStruct matching the `args` with the shapes
14421432
replaced with symbolic dimensions as specified by `shapes_specs`.
14431433
"""
1444-
if symbolic_constraints:
1445-
warnings.warn("symbolic_constraints is deprecated, use constraints",
1446-
DeprecationWarning, stacklevel=2)
1447-
if constraints:
1448-
raise ValueError("Cannot use both symbolic_constraints and constraints")
1449-
constraints = symbolic_constraints
1450-
if symbolic_scope is not None:
1451-
warnings.warn("symbolic_scope is deprecated, use scope",
1452-
DeprecationWarning, stacklevel=2)
1453-
if scope is not None:
1454-
raise ValueError("Cannot use both symbolic_scope and scope")
1455-
scope = symbolic_scope
1456-
14571434
polymorphic_shapes = shapes_specs
14581435
args_flat, args_tree = tree_util.tree_flatten(args)
14591436

0 commit comments

Comments
 (0)