Jitting jax.lax.pad raises TracerIntegerConversionError/ConcretizationTypeError #11535
Is there any way to make this function jit compile? My goal is a function that takes an array as input and pads it so it can be reshaped into any target shape. It raises an error while jit compiling. I have tried to illustrate the portion of the function that raises an error.

import jax
import numpy as np
import jax.numpy as jnp
s = (2, 7)
source_array = jnp.arange(np.prod(s), dtype = jnp.float32).reshape(s) + 1
# @partial(jax.jit, static_argnums=(1,))
@jax.jit
def pad_array(source_array, total_pad):
    pad_right = total_pad // 2
    pad_left = jnp.where(total_pad % 2, pad_right + 1, pad_right).item()
    padding_config = ((0,0,0),(pad_left, pad_right, 0))
    return jax.lax.pad(source_array, padding_value=0., padding_config = padding_config)
print(pad_array(source_array, 5))

pad_array runs without raising errors before jitting. If I try to jit compile it, it raises:

UnfilteredStackTrace Traceback (most recent call last)
[<ipython-input-376-73e723ef7883>](https://localhost:8080/#) in <module>()
8
----> 9 print(pad_array(source_array, 5))
UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
While tracing the function pad_array at <ipython-input-376-73e723ef7883>:1 for jit, this value became a tracer due to JAX operations on these lines:
operation a:i32[] = xla_call[
call_jaxpr={ lambda ; b:i32[] c:i32[] d:i32[]. let
e:bool[] = ne b 0
f:i32[] = select_n e d c
in (f,) }
name=_where
] g h i
from line <ipython-input-376-73e723ef7883>:5 (pad_array)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ConcretizationTypeError Traceback (most recent call last)
[/usr/local/lib/python3.7/dist-packages/jax/_src/device_array.py](https://localhost:8080/#) in item(self)
224 return float(self)
225 elif dtypes.issubdtype(self.dtype, np.integer):
--> 226 return int(self)
227 elif dtypes.issubdtype(self.dtype, np.bool_):
228 return bool(self)
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `int` function. If trying to convert the data type of a value, try using `x.astype(int)` or `jnp.array(x, int)` instead.
While tracing the function pad_array at <ipython-input-376-73e723ef7883>:1 for jit, this value became a tracer due to JAX operations on these lines:
operation a:i32[] = xla_call[
call_jaxpr={ lambda ; b:i32[] c:i32[] d:i32[]. let
e:bool[] = ne b 0
f:i32[] = select_n e d c
in (f,) }
name=_where
] g h i
from line <ipython-input-376-73e723ef7883>:5 (pad_array)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

I have also tried:

def pad_odd(a, p):
    padding_config = ((0,0,0),(p + 1, p, 0))
    return jax.lax.pad(a, padding_value=0., padding_config = padding_config)
def pad_even(a, p):
    padding_config = ((0,0,0),(p, p, 0))
    return jax.lax.pad(a, padding_value=0., padding_config = padding_config)
def pad_array(source_array, total_pad):
    pad_per_side = total_pad // 2
    return jax.lax.cond(total_pad % 2,
                        pad_odd,
                        pad_even,
                        source_array, pad_per_side)
print(pad_array(source_array, 5))

But it raises a similar error:

---------------------------------------------------------------------------
TracerIntegerConversionError Traceback (most recent call last)
[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in _dim_handler_and_canonical(*dlist)
1607 try:
-> 1608 canonical.append(operator.index(d))
1609 except TypeError:
32 frames
TracerIntegerConversionError: The __index__() method was called on the JAX Tracer object Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
During handling of the above exception, another exception occurred:
UnfilteredStackTrace Traceback (most recent call last)
UnfilteredStackTrace: TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 7).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last)
[<ipython-input-378-c03608473727>](https://localhost:8080/#) in pad_odd(a, p)
6 def pad_odd(a, p):
7 padding_config = ((0,0,0),(p + 1, p, 0))
----> 8 return jax.lax.pad(a, padding_value=0., padding_config = padding_config)
9
10 def pad_even(a, p):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 7).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
Replies: 1 comment
The thing to keep in mind is that in JAX's JIT, array shapes must be static. That means that the shape of an output array cannot depend on a traced value (for more on traced values, see How To Think In JAX). In your function, `total_pad` is a traced value, and the shape of the output array depends on `total_pad`, so it cannot be JIT-compiled.

The only way around this is to ensure that the shape of the output arrays does not depend on traced values; one way to do this is to mark the argument as static, using `@partial(jax.jit, static_argnums=1)` as you have in the comment above your function, and further to not perform any jax operations (like `jnp.where`) on the value you would like to be static; …

With all that considered, something like this is probably what you want:

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial
s = (2, 7)
source_array = jnp.arange(np.prod(s), dtype = jnp.float32).reshape(s) + 1
@partial(jax.jit, static_argnums=(1,))
def pad_array(source_array, total_pad):
    pad_right = total_pad // 2
    pad_left = np.where(total_pad % 2, pad_right + 1, pad_right) # Note: np rather than jnp when processing static values
    padding_config = ((0,0,0),(pad_left, pad_right, 0))
    return jax.lax.pad(source_array, padding_value=0., padding_config = padding_config)
print(pad_array(source_array, 5))
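
For illustration, here is a minimal sketch of the static-argument approach described in this reply. It is a variant and not the code from the thread: the function name `pad_last_axis` is hypothetical, the odd/even split is done with plain Python integer arithmetic instead of `np.where`, and the padding value is built with the input's dtype. It assumes padding is only wanted on the last axis of a 2D array.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper name; total_pad is marked static so that it stays
# an ordinary Python int inside the function instead of becoming a tracer.
@partial(jax.jit, static_argnums=(1,))
def pad_last_axis(x, total_pad):
    pad_right = total_pad // 2
    # The left side takes the extra element when total_pad is odd.
    pad_left = total_pad - pad_right
    # One (low, high, interior) triple per axis; axis 0 is left unpadded.
    padding_config = ((0, 0, 0), (pad_left, pad_right, 0))
    # Scalar padding value with the same dtype as the input.
    padding_value = jnp.zeros((), dtype=x.dtype)
    return jax.lax.pad(x, padding_value, padding_config)


x = jnp.arange(14, dtype=jnp.float32).reshape(2, 7) + 1
y = pad_last_axis(x, 5)
print(y.shape)  # (2, 12)
```

Because `total_pad` is static, each distinct value passed for it triggers a separate trace and compilation, so this fits best when only a handful of padding sizes occur.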