@@ -159,7 +159,7 @@ new shape:
159159It is possible to convert dimension expressions explicitly
160160to JAX arrays, with ` jnp.array(x.shape[0]) ` or even ` jnp.array(x.shape) ` .
161161The result of these operations can be used as regular JAX arrays,
162- bug cannot be used anymore as dimensions in shapes.
162+ but cannot be used anymore as dimensions in shapes, e.g., in ` reshape ` :
163163
164164``` python
165165>> > exp = export.export(jax.jit(lambda x : jnp.array(x.shape[0 ]) + x))(
@@ -616,45 +616,6 @@ Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-ass
616616These errors arise in a pre- processing step before the
617617compilation.
618618
619- # ## Division of symbolic dimensions is partially supported
620-
621- JAX will attempt to simplify division and modulo operations,
622- e.g., `(a * b + a) // (b + 1 ) == a` and `(6 * a + 4 ) % 3 == 1 ` .
623- In particular, JAX will handle the cases when either (a) there
624- is no remainder, or (b) the divisor is a constant
625- in which case there may be a constant remainder.
626-
627- For example, the code below results in a division error when trying to
628- compute the inferred dimension for a `reshape` operation:
629-
630- ```python
631- >> > b, = export.symbolic_shape(" b" )
632- >> > export.export(jax.jit(lambda x : x.reshape((2 , - 1 ))))(
633- ... jax.ShapeDtypeStruct((b,), dtype = np.int32))
634- Traceback (most recent call last):
635- jax._src.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (b,) and (2 , - 1 ).
636- The remainder mod(b, - 2 ) should be 0 .
637-
638- ```
639-
640- Note that the following will succeed:
641-
642- ```python
643- >> > b, = export.symbolic_shape(" b" )
644- >> > # We specify that the first dimension is a multiple of 4
645- >> > exp = export.export(jax.jit(lambda x : x.reshape((2 , - 1 ))))(
646- ... jax.ShapeDtypeStruct((4 * b,), dtype = np.int32))
647- >> > exp.out_avals
648- (ShapedArray(int32[2 ,2 * b]),)
649-
650- >> > # We specify that some other dimension is even
651- >> > exp = export.export(jax.jit(lambda x : x.reshape((2 , - 1 ))))(
652- ... jax.ShapeDtypeStruct((b, 5 , 6 ), dtype = np.int32))
653- >> > exp.out_avals
654- (ShapedArray(int32[2 ,15 * b]),)
655-
656- ```
657-
658619(shape_poly_debugging)=
659620# # Debugging
660621
0 commit comments