Skip to content

Commit 908865f

Browse files
Merge pull request jax-ml#25216 from gnecula:poly_doc
PiperOrigin-RevId: 702117006
2 parents 0134fa8 + b3c405c commit 908865f

File tree

1 file changed

+1
-40
lines changed

1 file changed

+1
-40
lines changed

docs/export/shape_poly.md

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ new shape:
159159
It is possible to convert dimension expressions explicitly
160160
to JAX arrays, with `jnp.array(x.shape[0])` or even `jnp.array(x.shape)`.
161161
The result of these operations can be used as regular JAX arrays,
162-
bug cannot be used anymore as dimensions in shapes.
162+
but cannot be used anymore as dimensions in shapes, e.g., in `reshape`:
163163

164164
```python
165165
>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
@@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass
616616
These errors arise in a pre-processing step before the
617617
compilation.
618618

619-
### Division of symbolic dimensions is partially supported
620-
621-
JAX will attempt to simplify division and modulo operations,
622-
e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`.
623-
In particular, JAX will handle the cases when either (a) there
624-
is no remainder, or (b) the divisor is a constant
625-
in which case there may be a constant remainder.
626-
627-
For example, the code below results in a division error when trying to
628-
compute the inferred dimension for a `reshape` operation:
629-
630-
```python
631-
>>> b, = export.symbolic_shape("b")
632-
>>> export.export(jax.jit(lambda x: x.reshape((2, -1))))(
633-
... jax.ShapeDtypeStruct((b,), dtype=np.int32))
634-
Traceback (most recent call last):
635-
jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2, -1).
636-
The remainder mod(b, - 2) should be 0.
637-
638-
```
639-
640-
Note that the following will succeed:
641-
642-
```python
643-
>>> b, = export.symbolic_shape("b")
644-
>>> # We specify that the first dimension is a multiple of 4
645-
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
646-
... jax.ShapeDtypeStruct((4*b,), dtype=np.int32))
647-
>>> exp.out_avals
648-
(ShapedArray(int32[2,2*b]),)
649-
650-
>>> # We specify that some other dimension is even
651-
>>> exp = export.export(jax.jit(lambda x: x.reshape((2, -1))))(
652-
... jax.ShapeDtypeStruct((b, 5, 6), dtype=np.int32))
653-
>>> exp.out_avals
654-
(ShapedArray(int32[2,15*b]),)
655-
656-
```
657-
658619
(shape_poly_debugging)=
659620
## Debugging
660621

0 commit comments

Comments
 (0)