Skip to content

Commit 4b1d843

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Accept named dim in DataclassArray FloatArray['... h w 3'] (without consistency checking)
PiperOrigin-RevId: 476835750
1 parent 7624853 commit 4b1d843

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2525

2626
* Changed: By default, dataclass_array do not cast and broadcast inputs
2727
anymore.
28+
* Changed: `dca.DataclassArray` fields can be annotated with named axis (e.g.
29+
`FloatArray['*shape h w 3']`). Note that consistency across fields is not
30+
checked yet.
2831
* Added: `@dca.dataclass_array` to customize the `dca.DataclassArray` params
2932

3033
## [1.1.0] - 2022-08-15

dataclass_array/shape_parsing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,16 @@ def get_inner_shape(shape_str: str) -> _Shape:
8888
shape = parser.parse(shape_str)
8989

9090
# TODO(epot): Reraise typing with `shape` debug message
91-
# TODO(epot): Support `_` & `None` dim
9291
if not shape or not isinstance(shape[0], _VarDim):
9392
raise ValueError(
9493
f'Shape {shape_str!r} should start by `...` or `*shape` (e.g. '
9594
'`f32[\'*shape 3\']`)')
9695

9796
inner_shape = shape[1:]
97+
# Currently, `_NamedDim` are accepted but consistency isn't checked across
98+
# fields
99+
inner_shape = tuple(
100+
None if isinstance(s, _NamedDim) else s for s in inner_shape)
98101
if not all(isinstance(dim, (int, type(None))) for dim in inner_shape):
99102
raise ValueError('Only static or None dimensions supported.')
100103

dataclass_array/shape_parsing_test.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def test_parse_shape_types():
6767
('... 3 5', (3, 5)),
6868
('... 3 5 7', (3, 5, 7)),
6969
('... 3 _ _ 7', (3, None, None, 7)),
70+
('... 3 h w 7', (3, None, None, 7)),
7071
('*shape', ()),
7172
('*shape _', (None,)),
7273
('*shape 3', (3,)),
@@ -92,11 +93,7 @@ def test_get_inner_shape_failure_first_dim(shape_str: str):
9293

9394
@pytest.mark.parametrize(
9495
'shape_str',
95-
[
96-
'... ...',
97-
'... 3 d 1',
98-
'*shape 3 d 1',
99-
],
96+
['... ...'],
10097
)
10198
def test_get_inner_shape_failure_dynamic(shape_str: str):
10299
with pytest.raises(ValueError, match='Only static or None dimension'):

0 commit comments

Comments
 (0)