Skip to content

Commit 68c1905

Browse files
PeterZhizhinThe dataclass_array Authors
authored and committed
Add dca.concat to dataclass_array
Allows dataclass arrays to be not only stacked, but also concatenated together. PiperOrigin-RevId: 546279663
1 parent e267037 commit 68c1905

File tree

4 files changed

+87
-12
lines changed

4 files changed

+87
-12
lines changed

CHANGELOG.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
## [1.4.3] - 2023-07-07
27+
28+
* Add `dca.concat` function, in addition to `dca.stack`.
29+
2630
## [1.4.2] - 2023-06-29
2731

2832
* Now require Python 3.9 (drop 3.8 support)
@@ -69,8 +73,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
6973

7074
* Initial release
7175

72-
[Unreleased]: https://github.com/google-research/dataclass_array/compare/v1.4.2...HEAD
73-
[1.4.1]: https://github.com/google-research/dataclass_array/compare/v1.4.1...v1.4.2
76+
[Unreleased]: https://github.com/google-research/dataclass_array/compare/v1.4.3...HEAD
77+
[1.4.3]: https://github.com/google-research/dataclass_array/compare/v1.4.2...v1.4.3
78+
[1.4.2]: https://github.com/google-research/dataclass_array/compare/v1.4.1...v1.4.2
7479
[1.4.1]: https://github.com/google-research/dataclass_array/compare/v1.4.0...v1.4.1
7580
[1.4.0]: https://github.com/google-research/dataclass_array/compare/v1.3.0...v1.4.0
7681
[1.3.0]: https://github.com/google-research/dataclass_array/compare/v1.2.1...v1.3.0

dataclass_array/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from dataclass_array.array_dataclass import array_field as field
3535
from dataclass_array.array_dataclass import dataclass_array
3636
from dataclass_array.array_dataclass import DataclassArray
37+
from dataclass_array.ops import concat
3738
from dataclass_array.ops import stack
3839
from dataclass_array.vectorization import vectorize_method
3940

dataclass_array/array_dataclass_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,23 @@ class PointDynamicShape(dca.DataclassArray):
734734
)
735735

736736

737+
@enp.testing.parametrize_xnp()
@pytest.mark.parametrize('batch_shape', [(1,), (3,)])
def test_concatenate(xnp: enp.NpModule, batch_shape: Shape):
  """Checks that `dca.concat` only grows the concatenated (leading) axis."""

  class TestConcatenateClass(dca.DataclassArray):
    x: FloatArray['*shape 3']
    y: FloatArray['*shape']

  p = TestConcatenateClass(
      x=xnp.zeros(batch_shape + (3,), dtype=xnp.float32),
      y=xnp.zeros(batch_shape, dtype=xnp.float32),
  )

  num_copies = 3
  p_concatenated = dca.concat([p] * num_copies)

  # `dca.concat` defaults to `axis=0`, so only the leading batch dimension is
  # multiplied; any trailing batch dimensions must be left untouched. (The
  # previous expression scaled every batch dimension, which is only correct
  # for 1-D batch shapes.)
  expected_batch_shape = (batch_shape[0] * num_copies,) + tuple(
      batch_shape[1:]
  )
  assert p_concatenated.x.shape == expected_batch_shape + (3,)
  assert p_concatenated.y.shape == expected_batch_shape
752+
753+
737754
def test_class_getitem():
738755
assert Point == Point # pylint: disable=comparison-with-itself
739756
assert Point[''] == Point['']

dataclass_array/ops.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,40 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import Iterable # pylint: disable=g-multiple-import
19+
import functools
20+
from typing import Any, Callable, Iterable, Optional # pylint: disable=g-multiple-import
2021

2122
from dataclass_array import array_dataclass
22-
from dataclass_array.typing import DcT
23+
from dataclass_array.typing import Array, DcT # pylint: disable=g-multiple-import
2324
from dataclass_array.utils import np_utils
25+
from etils import enp
2426
from etils import epy
2527

2628

27-
def stack(
28-
arrays: Iterable[DcT], # list[_DcT['*shape']]
29+
def _ops_base(
30+
arrays: Iterable[DcT],
2931
*,
30-
axis: int = 0,
31-
) -> DcT: # _DcT['len(arrays) *shape']:
32-
"""Stack dataclasses together."""
32+
axis: int,
33+
array_fn: Callable[
34+
[
35+
enp.NpModule,
36+
int,
37+
Any, # array_dataclass._ArrayField[Array['*din']],
38+
],
39+
Array['*dout'],
40+
],
41+
dc_fn: Optional[
42+
Callable[
43+
[
44+
enp.NpModule,
45+
int,
46+
Any, # array_dataclass._ArrayField[DcT],
47+
],
48+
DcT,
49+
]
50+
],
51+
) -> DcT:
52+
"""Base function for all ops."""
3353
arrays = list(arrays)
3454
first_arr = arrays[0]
3555

@@ -61,9 +81,41 @@ def stack(
6181
# jax.tree_map(lambda x, y: x+y, (None, 10), (1, 2)) == (None, 12)
6282
# Similarly, static values will be the ones from the first element.
6383
merged_arr = first_arr._map_field( # pylint: disable=protected-access
64-
array_fn=lambda f: xnp.stack( # pylint: disable=g-long-lambda
84+
array_fn=functools.partial(array_fn, xnp, axis),
85+
dc_fn=functools.partial(dc_fn, xnp, axis),
86+
)
87+
return merged_arr
88+
89+
90+
def stack(
    arrays: Iterable[DcT],  # list[_DcT['*shape']]
    *,
    axis: int = 0,
) -> DcT:  # _DcT['len(arrays) *shape']:
  """Stack dataclasses together along a new axis.

  Args:
    arrays: Dataclass arrays to stack. Any iterable is accepted, including
      one-shot iterators such as generators.
    axis: Axis forwarded to `xnp.stack` for every array field, and to the
      recursive `stack` call for nested dataclass fields.

  Returns:
    The merged dataclass array built by `_ops_base`.
  """
  # Materialize once: `_ops_base` consumes the iterable, and the callbacks
  # below iterate over `arrays` again for every field. Without this, a
  # generator input would already be exhausted when the first field is
  # stacked, and `xnp.stack` would receive an empty list.
  arrays = list(arrays)
  return _ops_base(
      arrays,
      axis=axis,
      array_fn=lambda xnp, axis, f: xnp.stack(  # pylint: disable=g-long-lambda
          [getattr(arr, f.name) for arr in arrays], axis=axis
      ),
      dc_fn=lambda xnp, axis, f: stack(  # pylint: disable=g-long-lambda
          [getattr(arr, f.name) for arr in arrays],
          axis=axis,
      ),
  )
107+
108+
109+
def concat(arrays: Iterable[DcT], *, axis: int = 0) -> DcT:
  """Concatenate dataclasses together along an existing axis.

  Args:
    arrays: Dataclass arrays to concatenate. Any iterable is accepted,
      including one-shot iterators such as generators.
    axis: Axis forwarded to `xnp.concatenate` for every array field, and to
      the recursive `concat` call for nested dataclass fields.

  Returns:
    The merged dataclass array built by `_ops_base`.
  """
  # Materialize once: `_ops_base` consumes the iterable, and the callbacks
  # below iterate over `arrays` again for every field. Without this, a
  # generator input would already be exhausted when the first field is
  # concatenated, and `xnp.concatenate` would receive an empty list.
  arrays = list(arrays)
  return _ops_base(
      arrays,
      axis=axis,
      array_fn=lambda xnp, axis, f: xnp.concatenate(  # pylint: disable=g-long-lambda
          [getattr(arr, f.name) for arr in arrays], axis=axis
      ),
      dc_fn=lambda xnp, axis, f: concat(  # pylint: disable=g-long-lambda
          [getattr(arr, f.name) for arr in arrays],
          axis=axis,
      ),
  )
69-
return merged_arr

0 commit comments

Comments
 (0)