Skip to content

Commit bd06fac

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Release dataclass array v1.0.0
PiperOrigin-RevId: 466078556
1 parent 26ad684 commit bd06fac

File tree

3 files changed

+87
-4
lines changed

3 files changed

+87
-4
lines changed

CHANGELOG.md

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26-
## [0.1.0] - 2022-01-01
26+
## [1.0.0] - 2022-08-08
2727

28-
* Initial release
28+
* Initial release
2929

30-
[Unreleased]: https://github.com/google-research/dataclass_array/compare/v0.1.0...HEAD
30+
31+
[Unreleased]: https://github.com/google-research/dataclass_array/compare/v1.0.0...HEAD
32+
[1.0.0]: https://github.com/google-research/dataclass_array/compare/v0.1.0...v1.0.0
3133
[0.1.0]: https://github.com/google-research/dataclass_array/releases/tag/v0.1.0

README.md

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,84 @@
11
# Dataclass Array
22

3+
`DataclassArray` are dataclasses which behave like numpy-like arrays (can be
4+
batched, reshaped, sliced,...), but are compatible with Jax, TensorFlow, and
5+
numpy (with torch support planned).
6+
7+
## Documentation
8+
9+
To create a `dca.DataclassArray`, take a frozen dataclass and:
10+
11+
* Inherit from `dca.DataclassArray`
12+
* Annotate the fields with `etils.array_types` to specify the inner shape and
13+
dtype of the array (see below for static or nested dataclass fields).
14+
15+
```python
16+
import dataclass_array as dca
17+
from etils.array_types import FloatArray
18+
19+
20+
@dataclasses.dataclass(frozen=True)
21+
class Ray(dca.DataclassArray):
22+
pos: FloatArray['*batch_shape 3']
23+
dir: FloatArray['*batch_shape 3']
24+
```
25+
26+
Afterwards, the dataclass can be used as a numpy array:
27+
28+
```python
29+
ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))
30+
31+
32+
ray.shape == (3,) # 3 rays batched together
33+
ray.pos.shape == (3, 3) # Individual fields still available
34+
35+
# Numpy slicing/indexing/masking
36+
ray = ray[..., 1:2]
37+
ray = ray[norm(ray.dir) > 1e-7]
38+
39+
# Shape transformation
40+
ray = ray.reshape((1, 3))
41+
ray = ray.reshape('h w -> w h') # Native einops support
42+
ray = ray.flatten()
43+
44+
# Stack multiple dataclass arrays together
45+
ray = dca.stack([ray0, ray1, ...])
46+
47+
# Supports TF, Jax, Numpy (torch planned) and can be easily converted
48+
ray = ray.as_jax() # as_np(), as_tf()
49+
ray.xnp == jax.numpy # `numpy`, `jax.numpy`, `tf.experimental.numpy`
50+
51+
# Compatibility `with jax.tree_util`, `jax.vmap`,..
52+
ray = jax.tree_util.tree_map(lambda x: x+1, ray)
53+
```
54+
55+
A `DataclassArray` has 2 types of fields:
56+
57+
* Array fields: Fields batched like numpy arrays, with reshape, slicing,...
58+
Can be `xnp.ndarray` or nested `dca.DataclassArray`.
59+
* Static fields: Other non-numpy field. Are not modified by reshaping,...
60+
Static fields are also ignored in `jax.tree_map`.
61+
62+
```python
63+
@dataclasses.dataclass(frozen=True)
64+
class MyArray(dca.DataclassArray):
65+
# Array fields
66+
a: FloatArray['*batch_shape 3'] # Defined by `etils.array_types`
67+
b: Ray # Nested DataclassArray (inner shape == `()`)
68+
69+
# Array fields explicitly defined
70+
c: Any = dca.field(shape=(3,), dtype=np.float32)
71+
d: Ray = dca.field(shape=(3,), dtype=Ray) # Nested DataclassArray
72+
73+
# Static field (everything not defined as above)
74+
e: float
75+
f: np.array
76+
```
77+
78+
## Installation
79+
80+
```sh
81+
pip install dataclass_array
82+
```
83+
384
*This is not an official Google product*

dataclass_array/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,6 @@
4343

4444
# A new PyPI release will be pushed everytime `__version__` is increased
4545
# When changing this, also update the CHANGELOG.md
46-
__version__ = '0.1.0'
46+
__version__ = '1.0.0'
4747

4848
del sys, pytest

0 commit comments

Comments
 (0)