Is there danger in converting 1x1 TrackedArrays into integers using base int? #11851
-
For context, I'm coming from Julia, where converting values during autodiff means you lose much of the information carried by the value's previous type. I know, however, that JAX sometimes works in subtly different ways from Julia, so I would appreciate knowing whether this is frowned upon or accepted. The reason I want to do this is that some functions require inputs of a certain type and will not accept TrackedArrays of any sort.
-
In general, converting concrete traced values to non-traced Python scalars or NumPy arrays will effectively drop those values from any automatic differentiation. We can see this by comparing the output of a simple function to the same function in which we use JAX internals to cast the input to a Python float. Here's a simple auto-diffed function:

```python
import jax

def f(x):
  return 2 * x

print(jax.grad(f)(1.0))
# 2.0
```

And here's what happens if we cast the traced input to a Python float (which requires using some JAX internals, as it's not recommended):

```python
def f(x):
  x_float = float(jax.core.get_aval(x).val)  # Relies on JAX internals; not recommended.
  return 2.0 * x_float

print(jax.grad(f)(1.0))
# 0.0
```

The result is incorrect, because we've unbundled the tracer and subverted the autodiff framework. To avoid this kind of issue, JAX does not give users any easy path to do this; for example, attempting to cast a concrete traced array to float will result in an error:

```python
def f(x):
  x_float = float(x)  # Conversion not allowed because it's unsafe.
  return 2.0 * x_float

jax.grad(f)(1.0)
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ConcreteArray(1.0, dtype=float32, weak_type=True)>with<JVPTrace(level=2/0)> with
# The problem arose with the `float` function. If trying to convert the data type of a value, try using `x.astype(float)` or `jnp.array(x, float)` instead.
```

Integers are a different case, though: they already side-step autodiff because there's no way to perturb an integer, so differentiating with respect to one will return a special symbolic zero with dtype `float0`:

```python
def f(x):
  return 2.0 * x

print(repr(jax.grad(f, allow_int=True)(1)))
# array((b'',), dtype=[('float0', 'V')])
```
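A small sketch of the same idea with a function of two arguments, assuming a JAX version that exposes `jax.dtypes.float0` publicly: differentiating with respect to a float and an integer argument at once, only the integer slot gets the `float0` zero, while the float argument receives an ordinary gradient. The function `g` and its inputs are hypothetical.

```python
import jax

def g(x, n):
  # Hypothetical example: x is a float argument, n is an integer argument.
  return x * n

# allow_int=True permits taking the "gradient" with respect to the integer n.
gx, gn = jax.grad(g, argnums=(0, 1), allow_int=True)(3.0, 2)

print(gx)                             # 2.0  (d/dx of x * n is n)
print(gn.dtype == jax.dtypes.float0)  # True (integer argument gets a float0 zero)
```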
Casting the traced integer to a Python `int` inside the function gives the same result:

```python
def f(x):
  x_int = int(x)  # Conversion allowed because it's safe.
  return 2.0 * x_int

print(repr(jax.grad(f, allow_int=True)(1)))
# array((b'',), dtype=[('float0', 'V')])
```

So to your question: for integer values within autodiff, you can either keep the traced value or cast it to an untraced value. Both options are generally safe within autodiff, because integers don't contribute to gradients in either case. Note that JAX itself uses this in several places, for example in specializing the implementation of
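For float values, the conversions suggested by the error message above can be compared against deliberately dropping the gradient, using only public functions. In this sketch (function names are hypothetical), `x.astype(...)` and `jnp.array(x, float)` keep the value traced, so the gradient survives; `jax.lax.stop_gradient` is the public way to treat a value as a constant for autodiff, and it yields the same `0.0` as the internals-based float cast.

```python
import jax
import jax.numpy as jnp

def f_astype(x):
  # astype returns a new traced array, so x stays connected to autodiff.
  return 2.0 * x.astype(jnp.float32)

def f_array(x):
  # jnp.array(x, float) likewise keeps the value traced.
  return 2.0 * jnp.array(x, float)

def f_stopped(x):
  # stop_gradient deliberately treats x as a constant for autodiff.
  return 2.0 * jax.lax.stop_gradient(x)

print(jax.grad(f_astype)(1.0))   # 2.0
print(jax.grad(f_array)(1.0))    # 2.0
print(jax.grad(f_stopped)(1.0))  # 0.0
```

The difference is intent: `stop_gradient` makes the zero gradient explicit in the code, whereas unwrapping a tracer hides it.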