Skip to content

Commit e707ede

Browse files
Merge pull request jax-ml#25034 from gnecula:poly_state
PiperOrigin-RevId: 698820458
2 parents 2178ed2 + 0831e2e commit e707ede

File tree

6 files changed

+100
-68
lines changed

6 files changed

+100
-68
lines changed

benchmarks/shape_poly_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import jax
1919
from jax import core
20-
from jax._src.numpy import lax_numpy
2120
from jax import export
2221

2322
jax.config.parse_flags_with_absl()
@@ -76,7 +75,7 @@ def inequalities_slice(state):
7675
while state:
7776
for _ in range(30):
7877
a.scope._clear_caches()
79-
start, _, slice_size = lax_numpy._preprocess_slice(slice(2, a, 4), b)
78+
start, _, slice_size = core.canonicalize_slice(slice(2, a, 4), b)
8079
_ = 0 <= slice_size <= b
8180
_ = start >= 0
8281
_ = start + slice_size <= b

jax/_src/core.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2063,6 +2063,70 @@ def dimension_as_value(d: DimSize):
20632063
if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
20642064
return operator.index(d)
20652065

2066+
def canonicalize_slice(
2067+
s: slice,
2068+
axis_size: DimSize
2069+
) -> tuple[DimSize, DimSize, DimSize]:
2070+
"""Computes the start index, step, and size of the slice `x[s]`.
2071+
2072+
This is similar to `s.indices(axis_size)`, except that it returns
2073+
`(start, step, size)`, and it works when the slice and/or the
2074+
`axis_size` are symbolic.
2075+
2076+
See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
2077+
"""
2078+
def convert_to_index(d: DimSize) -> DimSize:
2079+
# Convert np.array and jax.Array to int, leave symbolic dimensions alone
2080+
try:
2081+
return operator.index(d)
2082+
except:
2083+
return d
2084+
2085+
# Must resolve statically if step is {<0, ==0, >0}
2086+
step = convert_to_index(s.step) if s.step is not None else 1
2087+
try:
2088+
if step == 0:
2089+
raise ValueError("slice step cannot be zero")
2090+
step_gt_0 = (step > 0)
2091+
except InconclusiveDimensionOperation as e:
2092+
raise InconclusiveDimensionOperation(
2093+
f"In slice with non-constant elements the step ({step}) must " +
2094+
f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
2095+
2096+
def clamp_index(i: DimSize, which: str):
2097+
try:
2098+
i_ge_0 = (i >= 0)
2099+
except InconclusiveDimensionOperation as e:
2100+
raise InconclusiveDimensionOperation(
2101+
f"In slice with non-constant elements the {which} ({i}) must " +
2102+
f"be resolved statically if it is >= 0.\nDetails: {e}")
2103+
if i_ge_0:
2104+
if step_gt_0:
2105+
return min_dim(axis_size, i)
2106+
else:
2107+
return min_dim(axis_size - 1, i)
2108+
else:
2109+
if step_gt_0:
2110+
return max_dim(0, axis_size + i)
2111+
else:
2112+
return max_dim(-1, axis_size + i)
2113+
2114+
if s.start is None:
2115+
start = 0 if step_gt_0 else axis_size - 1
2116+
else:
2117+
start = clamp_index(convert_to_index(s.start), "start")
2118+
2119+
if s.stop is None:
2120+
stop = axis_size if step_gt_0 else -1
2121+
else:
2122+
stop = clamp_index(convert_to_index(s.stop), "stop")
2123+
2124+
gap = step if step_gt_0 else - step
2125+
distance = (stop - start) if step_gt_0 else (start - stop)
2126+
slice_size = max_dim(0, distance + gap - 1) // gap
2127+
return start, step, slice_size
2128+
2129+
20662130
class SomeTracer:
20672131
__slots__ = ()
20682132
def __repr__(self): return "[dynamic]"

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -12116,7 +12116,7 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
1211612116
"arrays within JIT compiled functions).")
1211712117
raise IndexError(msg)
1211812118

12119-
start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
12119+
start, step, slice_size = core.canonicalize_slice(i, x_shape[x_axis])
1212012120
slice_shape.append(slice_size)
1212112121

1212212122
if core.definitely_equal(step, 1):
@@ -12319,65 +12319,6 @@ def _canonicalize_tuple_index(arr_ndim, idx):
1231912319
idx = tuple(idx) + colons
1232012320
return idx
1232112321

12322-
def _preprocess_slice(
12323-
s: slice,
12324-
axis_size: core.DimSize
12325-
) -> tuple[core.DimSize, core.DimSize, core.DimSize]:
12326-
"""Computes the start index, step, and size of the slice `x[s]`."""
12327-
# See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
12328-
# "this is harder to get right than you may think"
12329-
# (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)
12330-
def convert_to_index(d: DimSize) -> DimSize:
12331-
# Convert np.array and jax.Array to int, leave symbolic dimensions alone
12332-
try:
12333-
return operator.index(d)
12334-
except:
12335-
return d
12336-
12337-
# Must resolve statically if step is {<0, ==0, >0}
12338-
step = convert_to_index(s.step) if s.step is not None else 1
12339-
try:
12340-
if step == 0:
12341-
raise ValueError("slice step cannot be zero")
12342-
step_gt_0 = (step > 0)
12343-
except core.InconclusiveDimensionOperation as e:
12344-
raise core.InconclusiveDimensionOperation(
12345-
f"In slice with non-constant elements the step ({step}) must " +
12346-
f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
12347-
12348-
def clamp_index(i: DimSize, which: str):
12349-
try:
12350-
i_ge_0 = (i >= 0)
12351-
except core.InconclusiveDimensionOperation as e:
12352-
raise core.InconclusiveDimensionOperation(
12353-
f"In slice with non-constant elements the {which} ({i}) must " +
12354-
f"be resolved statically if it is >= 0.\nDetails: {e}")
12355-
if i_ge_0:
12356-
if step_gt_0:
12357-
return core.min_dim(axis_size, i)
12358-
else:
12359-
return core.min_dim(axis_size - 1, i)
12360-
else:
12361-
if step_gt_0:
12362-
return core.max_dim(0, axis_size + i)
12363-
else:
12364-
return core.max_dim(-1, axis_size + i)
12365-
12366-
if s.start is None:
12367-
start = 0 if step_gt_0 else axis_size - 1
12368-
else:
12369-
start = clamp_index(convert_to_index(s.start), "start")
12370-
12371-
if s.stop is None:
12372-
stop = axis_size if step_gt_0 else -1
12373-
else:
12374-
stop = clamp_index(convert_to_index(s.stop), "stop")
12375-
12376-
gap = step if step_gt_0 else - step
12377-
distance = (stop - start) if step_gt_0 else (start - stop)
12378-
slice_size = core.max_dim(0, distance + gap - 1) // gap
12379-
return start, step, slice_size
12380-
1238112322

1238212323
@export
1238312324
def blackman(M: int) -> Array:

jax/_src/state/indexing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ def __post_init__(self):
4646

4747
@property
4848
def is_dynamic_start(self):
49-
return not isinstance(self.start, int)
49+
return not core.is_dim(self.start)
5050

5151
@property
5252
def is_dynamic_size(self):
53-
return not isinstance(self.size, int)
53+
return not core.is_dim(self.size)
5454

5555
def tree_flatten(self):
5656
# If `start` is statically known, we treat it as static information
@@ -72,10 +72,10 @@ def tree_unflatten(cls, aux_data, children) -> Slice:
7272

7373
@classmethod
7474
def from_slice(cls, slc: slice, size: int) -> Slice:
75-
start, stop, step = slc.indices(size)
75+
start, step, size = core.canonicalize_slice(slc, size)
7676
if step < 1:
7777
raise ValueError(f"slice must have a step >= 1 (found: {step})")
78-
return cls(start, max((stop - start + step - 1) // step, 0), step)
78+
return cls(start, size, step)
7979

8080

8181
def dslice(

tests/shape_poly_test.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@
4848
from jax._src.export import shape_poly_decision
4949
from jax._src.lax import lax as lax_internal
5050
from jax._src.lax import control_flow as lax_control_flow
51+
from jax._src.state import discharge
52+
from jax._src.state import primitives as ref_primitives
53+
5154
import numpy as np
5255

5356
config.parse_flags_with_absl()
@@ -2062,6 +2065,31 @@ def test_vmap_error(self):
20622065
polymorphic_shapes=["b, ...", "c, ...", None])
20632066

20642067

2068+
@jtu.parameterized_filterable(
2069+
kwargs=[
2070+
dict(slc=slc)
2071+
for slc in [
2072+
slice(None, None, None),
2073+
slice(2, 5),
2074+
]
2075+
])
2076+
def test_stateful(self, slc: slice):
2077+
w, = export.symbolic_shape("w", constraints=["w >= 3"])
2078+
def f(x_ref):
2079+
ones = jnp.ones_like(x_ref)[slc]
2080+
ref_primitives.ref_addupdate(x_ref, slc, ones)
2081+
x1 = ref_primitives.ref_get(x_ref, slc)
2082+
x2 = x1 + ones
2083+
ref_primitives.ref_set(x_ref, slc, x2)
2084+
2085+
exp = export.export(jax.jit(discharge.run_state(f)))(
2086+
jax.ShapeDtypeStruct((w,), dtype=_f32))
2087+
x = np.ones((32,), dtype=_f32)
2088+
expected = np.copy(x)
2089+
expected[slc] = 3.
2090+
self.assertAllClose(exp.call(x), expected)
2091+
2092+
20652093
# List containing either harnesses, or lists of harnesses
20662094
_POLY_SHAPE_TEST_HARNESSES = [
20672095
PolyHarness("add", "",
@@ -3603,7 +3631,7 @@ def test_harness(self, harness: PolyHarness):
36033631
not harness.polymorphic_shapes[0].endswith("...") and
36043632
jtu.test_device_matches(["tpu"])):
36053633
raise unittest.SkipTest(
3606-
"Shape polymorphsim for Eigh and Svd is only supported for batch dimensions on TPU.")
3634+
"Shape polymorphism for Eigh and Svd is only supported for batch dimensions on TPU.")
36073635

36083636
config_flags = harness.override_jax_config_flags
36093637
# Update this here rather than in harness object because vmap_random_gamma is derived

tests/state_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -752,7 +752,7 @@ def f(a_ref, b_ref):
752752
lu.wrap_init(f), [scalar_ref_1, scalar_ref_2])
753753

754754
discharged_jaxpr, _ = discharge_state(jaxpr, (), should_discharge=[False, True])
755-
prim_count = lambda p, jaxpr: sum(eqn.primitive == swap_p for eqn in jaxpr.eqns)
755+
prim_count = lambda p, jaxpr: sum(eqn.primitive == p for eqn in jaxpr.eqns)
756756
self.assertEqual(prim_count(swap_p, jaxpr) // 2, prim_count(swap_p, discharged_jaxpr))
757757
self.assertEqual(prim_count(get_p, jaxpr) // 2, prim_count(get_p, discharged_jaxpr))
758758

0 commit comments

Comments
 (0)