Skip to content

Commit 43695c3

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Update readme, Add badge and cross links,
PiperOrigin-RevId: 466330294
1 parent bd06fac commit 43695c3

File tree

1 file changed

+109
-2
lines changed

1 file changed

+109
-2
lines changed

README.md

Lines changed: 109 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
11
# Dataclass Array
22

3+
[![Unittests](https://github.com/google-research/dataclass_array/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-research/visu3d/actions/workflows/pytest_and_autopublish.yml)
4+
[![PyPI version](https://badge.fury.io/py/dataclass_array.svg)](https://badge.fury.io/py/dataclass_array)
5+
36
`DataclassArray` are dataclasses which behave like numpy-like arrays (can be
4-
batched, reshaped, sliced,...), but are compatible with Jax, TensorFlow, and
5-
numpy (with torch support planned).
7+
batched, reshaped, sliced,...), compatible with Jax, TensorFlow, and numpy (with
8+
torch support planned).
9+
10+
This reduce boilerplate and improve readability. See the
11+
[motivating examples](#motivating-examples) section bellow.
12+
13+
To view an example of dataclass arrays used in practice, see
14+
[visu3d](https://github.com/google-research/visu3d).
615

716
## Documentation
817

18+
### Definition
19+
920
To create a `dca.DataclassArray`, take a frozen dataclass and:
1021

1122
* Inherit from `dca.DataclassArray`
@@ -23,6 +34,8 @@ class Ray(dca.DataclassArray):
2334
dir: FloatArray['*batch_shape 3']
2435
```
2536

37+
### Usage
38+
2639
Afterwards, the dataclass can be used as a numpy array:
2740

2841
```python
@@ -75,6 +88,100 @@ class MyArray(dca.DataclassArray):
7588
f: np.array
7689
```
7790

91+
### Vectorization
92+
93+
`@dca.vectorize_method` allow your dataclass method to automatically support
94+
batching:
95+
96+
1. Implement method as if `self.shape == ()`
97+
2. Decorate the method with `dca.vectorize_method`
98+
99+
```python
100+
@dataclasses.dataclass(frozen=True)
101+
class Camera(dca.DataclassArray):
102+
K: FloatArray['*batch_shape 4 4']
103+
resolution = tuple[int, int]
104+
105+
@dca.vectorize_method
106+
def rays(self) -> Ray:
107+
# Inside `@dca.vectorize_method` shape is always guarantee to be `()`
108+
assert self.shape == ()
109+
assert self.K.shape == (4, 4)
110+
111+
# Compute the ray as if there was only a single camera
112+
return Ray(pos=..., dir=...)
113+
```
114+
115+
Afterward, we can generate rays for multiple camera batched together:
116+
117+
```python
118+
cams = Camera(K=K) # K.shape == (num_cams, 4, 4)
119+
rays = cams.rays() # Generate the rays for all the cameras
120+
121+
cams.shape == (num_cams,)
122+
rays.shape == (num_cams, h, w)
123+
```
124+
125+
`@dca.vectorize_method` is similar to `jax.vmap` but:
126+
127+
* Only work on `dca.DataclassArray` methods
128+
* Instead of vectorizing a single axis, `@dca.vectorize_method` will vectorize
129+
over `*self.shape` (not just `self.shape[0]`). This is like if `vmap` was
130+
applied to `self.flatten()`
131+
* When multiple arguments, axis with dimension `1` are brodcasted.
132+
133+
For example, with `__matmul__(self, x: T) -> T`:
134+
135+
```python
136+
() @ (*x,) -> (*x,)
137+
(b,) @ (b, *x) -> (b, *x)
138+
(b,) @ (1, *x) -> (b, *x)
139+
(1,) @ (b, *x) -> (b, *x)
140+
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
141+
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
142+
(a, *x) @ (b, *x) -> Error: Incompatible a != b
143+
```
144+
145+
To test on Colab, see the `visu3d`
146+
dataclass [Colab tutorial](https://colab.research.google.com/github/google-research/visu3d/blob/main/docs/dataclass.ipynb).
147+
148+
## Motivating examples
149+
150+
`dca.DataclassArray` improve readability by simplifying common patterns:
151+
152+
* Reshaping all fields of a dataclass:
153+
154+
Before (`rays` is simple `dataclass`):
155+
156+
```python
157+
num_rays = math.prod(rays.origins.shape[:-1])
158+
rays = jax.tree_map(lambda r: r.reshape((num_rays, -1)), rays)
159+
```
160+
161+
After (`rays` is `DataclassArray`):
162+
163+
```python
164+
rays = rays.flatten() # (b, h, w) -> (b*h*w,)
165+
```
166+
167+
* Rendering a video:
168+
169+
Before (`cams: list[Camera]`):
170+
171+
```python
172+
img = cams[0].render(scene)
173+
imgs = np.stack([cam.render(scene) for cam in cams[::2]])
174+
imgs = np.stack([cam.render(scene) for cam in cams])
175+
```
176+
177+
After (`cams: Camera` with `cams.shape == (num_cams,)`):
178+
179+
```python
180+
img = cams[0].render(scene) # Render only the first camera (to debug)
181+
imgs = cams[::2].render(scene) # Render 1/2 frames (for quicker iteration)
182+
imgs = cams.render(scene) # Render all cameras at once
183+
```
184+
78185
## Installation
79186

80187
```sh

0 commit comments

Comments
 (0)