11# Dataclass Array
22
3+ [ ![ Unittests] ( https://github.com/google-research/dataclass_array/actions/workflows/pytest_and_autopublish.yml/badge.svg )] ( https://github.com/google-research/visu3d/actions/workflows/pytest_and_autopublish.yml )
4+ [ ![ PyPI version] ( https://badge.fury.io/py/dataclass_array.svg )] ( https://badge.fury.io/py/dataclass_array )
5+
36` DataclassArray ` are dataclasses which behave like numpy-like arrays (can be
4- batched, reshaped, sliced,...), but are compatible with Jax, TensorFlow, and
5- numpy (with torch support planned).
7+ batched, reshaped, sliced,...), compatible with Jax, TensorFlow, and numpy (with
8+ torch support planned).
9+
10+ This reduce boilerplate and improve readability. See the
11+ [ motivating examples] ( #motivating-examples ) section bellow.
12+
13+ To view an example of dataclass arrays used in practice, see
14+ [ visu3d] ( https://github.com/google-research/visu3d ) .
615
716## Documentation
817
18+ ### Definition
19+
920To create a ` dca.DataclassArray ` , take a frozen dataclass and:
1021
1122* Inherit from ` dca.DataclassArray `
@@ -23,6 +34,8 @@ class Ray(dca.DataclassArray):
2334 dir : FloatArray[' *batch_shape 3' ]
2435```
2536
37+ ### Usage
38+
2639Afterwards, the dataclass can be used as a numpy array:
2740
2841``` python
@@ -75,6 +88,100 @@ class MyArray(dca.DataclassArray):
7588 f: np.array
7689```
7790
91+ ### Vectorization
92+
93+ ` @dca.vectorize_method ` allow your dataclass method to automatically support
94+ batching:
95+
96+ 1 . Implement method as if ` self.shape == () `
97+ 2 . Decorate the method with ` dca.vectorize_method `
98+
99+ ``` python
100+ @dataclasses.dataclass (frozen = True )
101+ class Camera (dca .DataclassArray ):
102+ K: FloatArray[' *batch_shape 4 4' ]
103+ resolution = tuple[int , int ]
104+
105+ @dca.vectorize_method
106+ def rays (self ) -> Ray:
107+ # Inside `@dca.vectorize_method` shape is always guarantee to be `()`
108+ assert self .shape == ()
109+ assert self .K.shape == (4 , 4 )
110+
111+ # Compute the ray as if there was only a single camera
112+ return Ray(pos = ... , dir = ... )
113+ ```
114+
115+ Afterward, we can generate rays for multiple camera batched together:
116+
117+ ``` python
118+ cams = Camera(K = K) # K.shape == (num_cams, 4, 4)
119+ rays = cams.rays() # Generate the rays for all the cameras
120+
121+ cams.shape == (num_cams,)
122+ rays.shape == (num_cams, h, w)
123+ ```
124+
125+ ` @dca.vectorize_method ` is similar to ` jax.vmap ` but:
126+
127+ * Only work on ` dca.DataclassArray ` methods
128+ * Instead of vectorizing a single axis, ` @dca.vectorize_method ` will vectorize
129+ over ` *self.shape ` (not just ` self.shape[0] ` ). This is like if ` vmap ` was
130+ applied to ` self.flatten() `
131+ * When multiple arguments, axis with dimension ` 1 ` are brodcasted.
132+
133+ For example, with ` __matmul__(self, x: T) -> T ` :
134+
135+ ``` python
136+ () @ (* x,) -> (* x,)
137+ (b,) @ (b, * x) -> (b, * x)
138+ (b,) @ (1 , * x) -> (b, * x)
139+ (1 ,) @ (b, * x) -> (b, * x)
140+ (b, h, w) @ (b, h, w, * x) -> (b, h, w, * x)
141+ (1 , h, w) @ (b, 1 , 1 , * x) -> (b, h, w, * x)
142+ (a, * x) @ (b, * x) -> Error: Incompatible a != b
143+ ```
144+
145+ To test on Colab, see the ` visu3d `
146+ dataclass [ Colab tutorial] ( https://colab.research.google.com/github/google-research/visu3d/blob/main/docs/dataclass.ipynb ) .
147+
148+ ## Motivating examples
149+
150+ ` dca.DataclassArray ` improve readability by simplifying common patterns:
151+
152+ * Reshaping all fields of a dataclass:
153+
154+ Before (` rays ` is simple ` dataclass ` ):
155+
156+ ``` python
157+ num_rays = math.prod(rays.origins.shape[:- 1 ])
158+ rays = jax.tree_map(lambda r : r.reshape((num_rays, - 1 )), rays)
159+ ```
160+
161+ After (`rays` is `DataclassArray` ):
162+
163+ ```python
164+ rays = rays.flatten() # (b, h, w) -> (b*h*w,)
165+ ```
166+
167+ * Rendering a video:
168+
169+ Before (`cams: list[Camera]` ):
170+
171+ ```python
172+ img = cams[0 ].render(scene)
173+ imgs = np.stack([cam.render(scene) for cam in cams[::2 ]])
174+ imgs = np.stack([cam.render(scene) for cam in cams])
175+ ```
176+
177+ After (`cams: Camera` with `cams.shape == (num_cams,)` ):
178+
179+ ```python
180+ img = cams[0 ].render(scene) # Render only the first camera (to debug)
181+ imgs = cams[::2 ].render(scene) # Render 1/2 frames (for quicker iteration)
182+ imgs = cams.render(scene) # Render all cameras at once
183+ ```
184+
78185# # Installation
79186
80187```sh
0 commit comments