|
1 | 1 | # Dataclass Array |
2 | 2 |
|
| 3 | +`DataclassArray` are dataclasses which behave like numpy-like arrays (can be |
| 4 | +batched, reshaped, sliced,...), but are compatible with Jax, TensorFlow, and |
| 5 | +numpy (with torch support planned). |
| 6 | + |
| 7 | +## Documentation |
| 8 | + |
| 9 | +To create a `dca.DataclassArray`, take a frozen dataclass and: |
| 10 | + |
| 11 | +* Inherit from `dca.DataclassArray` |
| 12 | +* Annotate the fields with `etils.array_types` to specify the inner shape and |
| 13 | + dtype of the array (see below for static or nested dataclass fields). |
| 14 | + |
| 15 | +```python |
| 16 | +import dataclass_array as dca |
| 17 | +from etils.array_types import FloatArray |
| 18 | + |
| 19 | + |
| 20 | +@dataclasses.dataclass(frozen=True) |
| 21 | +class Ray(dca.DataclassArray): |
| 22 | + pos: FloatArray['*batch_shape 3'] |
| 23 | + dir: FloatArray['*batch_shape 3'] |
| 24 | +``` |
| 25 | + |
| 26 | +Afterwards, the dataclass can be used as a numpy array: |
| 27 | + |
| 28 | +```python |
| 29 | +ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3)) |
| 30 | + |
| 31 | + |
| 32 | +ray.shape == (3,) # 3 rays batched together |
| 33 | +ray.pos.shape == (3, 3) # Individual fields still available |
| 34 | + |
| 35 | +# Numpy slicing/indexing/masking |
| 36 | +ray = ray[..., 1:2] |
| 37 | +ray = ray[norm(ray.dir) > 1e-7] |
| 38 | + |
| 39 | +# Shape transformation |
| 40 | +ray = ray.reshape((1, 3)) |
| 41 | +ray = ray.reshape('h w -> w h') # Native einops support |
| 42 | +ray = ray.flatten() |
| 43 | + |
| 44 | +# Stack multiple dataclass arrays together |
| 45 | +ray = dca.stack([ray0, ray1, ...]) |
| 46 | + |
| 47 | +# Supports TF, Jax, Numpy (torch planned) and can be easily converted |
| 48 | +ray = ray.as_jax() # as_np(), as_tf() |
| 49 | +ray.xnp == jax.numpy # `numpy`, `jax.numpy`, `tf.experimental.numpy` |
| 50 | + |
| 51 | +# Compatibility `with jax.tree_util`, `jax.vmap`,.. |
| 52 | +ray = jax.tree_util.tree_map(lambda x: x+1, ray) |
| 53 | +``` |
| 54 | + |
| 55 | +A `DataclassArray` has 2 types of fields: |
| 56 | + |
| 57 | +* Array fields: Fields batched like numpy arrays, with reshape, slicing,... |
| 58 | + Can be `xnp.ndarray` or nested `dca.DataclassArray`. |
| 59 | +* Static fields: Other non-numpy field. Are not modified by reshaping,... |
| 60 | + Static fields are also ignored in `jax.tree_map`. |
| 61 | + |
| 62 | +```python |
| 63 | +@dataclasses.dataclass(frozen=True) |
| 64 | +class MyArray(dca.DataclassArray): |
| 65 | + # Array fields |
| 66 | + a: FloatArray['*batch_shape 3'] # Defined by `etils.array_types` |
| 67 | + b: Ray # Nested DataclassArray (inner shape == `()`) |
| 68 | + |
| 69 | + # Array fields explicitly defined |
| 70 | + c: Any = dca.field(shape=(3,), dtype=np.float32) |
| 71 | + d: Ray = dca.field(shape=(3,), dtype=Ray) # Nested DataclassArray |
| 72 | + |
| 73 | + # Static field (everything not defined as above) |
| 74 | + e: float |
| 75 | + f: np.array |
| 76 | +``` |
| 77 | + |
| 78 | +## Installation |
| 79 | + |
| 80 | +```sh |
| 81 | +pip install dataclass_array |
| 82 | +``` |
| 83 | + |
3 | 84 | *This is not an official Google product* |
0 commit comments