Is there a way to subclass jnp.ndarray/jax.Array to add a custom field #12901
-
This isn't possible in JAX, and that is very much by design: internally, JAX needs to swap arrays out for Tracers and other non-array objects when JIT-compiling, differentiating, vmap'ing, and so on. Fortunately, it's still fairly easy to carry additional metadata around, just by wrapping your array in a pytree. For this example I'll use Equinox for easy handling of pytrees, but if you prefer, you can do something analogous by registering your own custom pytrees.

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from typing import Any

class HasMetadata(eqx.Module):
    array: jnp.ndarray
    metadata: Any

def is_metadata(x):
    return isinstance(x, HasMetadata)

def resolve_metadata(x):
    if is_metadata(x):
        # Or if you wish, do some other processing based on the value of the metadata.
        return x.array
    else:
        return x

linear = eqx.nn.Linear(...)
meta_weight = HasMetadata(linear.weight, {"your": "metadata"})
linear2 = eqx.tree_at(lambda x: x.weight, linear, meta_weight)
# linear2 has metadata annotated on its weight
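
# eqx.filter_jit / eqx.filter_grad are used instead of jax.jit / jax.grad: they treat
# non-array leaves (such as the strings inside the metadata dict) as static, whereas
# jax.jit / jax.grad would reject them as invalid JAX types.
@eqx.filter_jit
@eqx.filter_grad
def loss(model, x, y):
    model = jtu.tree_map(resolve_metadata, model, is_leaf=is_metadata)
    pred_y = model(x)
    return jnp.mean((y - pred_y)**2)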

loss(linear2, ...)

You can also use the same trick for, for example, creating linear layers with symmetric weight matrices: make the "resolution" step return
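For anyone who would rather not depend on Equinox, here is a minimal sketch of the "register your own custom pytree" route, using `jax.tree_util.register_pytree_node_class`. The `params` layout and the linear model are illustrative assumptions, not part of the original answer. The main design choice is that the metadata is returned as the pytree's `aux_data`, so JAX carries it as a static Python value (never traced or differentiated) while the array itself becomes a Tracer under `jit`/`grad`. Because `aux_data` takes part in `jit`'s cache key, the metadata should be hashable (for example a string or tuple rather than a dict).

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class HasMetadata:
    """Hypothetical wrapper pairing an array with static metadata."""

    def __init__(self, array, metadata):
        self.array = array
        self.metadata = metadata

    def tree_flatten(self):
        # The array is a child (traced/differentiated); metadata is static aux_data.
        return (self.array,), self.metadata

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(children[0], metadata)


def is_metadata(x):
    return isinstance(x, HasMetadata)


def resolve_metadata(x):
    # Strip the wrapper; other processing based on x.metadata could go here.
    return x.array if is_metadata(x) else x


@jax.jit
@jax.grad
def loss(params, x, y):
    params = jax.tree_util.tree_map(resolve_metadata, params, is_leaf=is_metadata)
    pred_y = x @ params["w"] + params["b"]  # hypothetical linear model
    return jnp.mean((y - pred_y) ** 2)


params = {"w": HasMetadata(jnp.ones(3), "weight"), "b": jnp.zeros(())}
x = jnp.ones((4, 3))
y = jnp.zeros(4)
grads = loss(params, x, y)
```

Because `jax.grad` returns gradients with the same pytree structure as `params`, the gradient for `"w"` comes back still wrapped in a `HasMetadata` carrying the `"weight"` tag.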
-
Thanks for the answer, Patrick! We are aware of this approach, and it is indeed what we used in our initial implementation. The problem is that we have a lot of code that does things like We've thought about this a lot, and the only solution we see is to be able to subclass an array to add some extra fields.
-
We have a general-purpose codebase that we would like to extend in a way that allows tagging our parameters (jnp.arrays) with metadata. The least intrusive solution, for both implementation and typing purposes, would be to simply extend the jnp.ndarray type with an additional field. This is supported in NumPy, which provides detailed documentation on how to do so [link]. Is there a similar possibility in JAX?
Alternatively, is there a way to create an object that, for all intents and purposes, behaves exactly like a jnp.ndarray in JAX? That would solve the implementation side of things, and I'm sure we could do some typing magic to make this object look like a jax.Array.
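For reference, the NumPy pattern the question alludes to looks roughly like the sketch below, adapted from the view-casting approach in NumPy's "Subclassing ndarray" guide; the class name `InfoArray` and the attribute `info` are illustrative. NumPy propagates the field through views and slices via `__array_finalize__`, but JAX has no equivalent hook: converting such an array with `jnp.asarray` yields a plain `jax.Array`, and the extra field is gone.

```python
import jax
import jax.numpy as jnp
import numpy as np


class InfoArray(np.ndarray):
    """Illustrative ndarray subclass carrying an extra `info` field."""

    def __new__(cls, input_array, info=None):
        # View-cast an existing array to this subclass, then attach the field.
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        # Runs for views and slices too; copy `info` over from the source array.
        if obj is None:
            return
        self.info = getattr(obj, "info", None)


a = InfoArray(np.arange(3.0), info={"your": "metadata"})
b = a[1:]           # NumPy: slicing keeps both the subclass and the field
j = jnp.asarray(a)  # JAX: conversion produces a plain jax.Array without `info`
```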