2525import pytest
2626
2727
28+ _DS = dca .field_utils .DataclassWithShape
2829Ray = dca .testing .Ray
2930
3031
@@ -39,7 +40,10 @@ class Camera(dca.DataclassArray):
3940 [
4041 (int , [int ]),
4142 (Ray , [Ray ]),
43+ (Ray ['h w' ], [Ray ['h w' ]]),
44+ (Ray [..., 3 ], [Ray [..., 3 ]]),
4245 (Union [Ray , int ], [Ray , int ]),
46+ (Union [Ray ['h w' ], int ], [Ray ['h w' ], int ]),
4347 (Union [Ray , int , None ], [Ray , int , None ]),
4448 (Optional [Ray ], [Ray , None ]),
4549 (Optional [Union [Ray , int ]], [Ray , int , None ]),
@@ -55,10 +59,12 @@ def test_get_leaf_types(hint, expected):
5559 'hint, expected' ,
5660 [
5761 (int , None ),
58- (Ray , Ray ),
59- (Optional [Ray ], Ray ),
60- (Union [Ray , Camera ], dca .DataclassArray ),
61- (Union [Ray , Camera , None ], dca .DataclassArray ),
62+ (Ray , _DS (Ray , '...' )),
63+ (Ray ['h w' ], _DS (Ray , 'h w' )),
64+ (Ray [..., 3 ], _DS (Ray , '... 3' )),
65+ (Optional [Ray ], _DS (Ray , '...' )),
66+ (Union [Ray , Camera ], _DS (dca .DataclassArray , '...' )),
67+ (Union [Ray , Camera , None ], _DS (dca .DataclassArray , '...' )),
6268 (Union [Ray , int ], None ),
6369 (Union [Ray , int , None ], None ),
6470 (Union [f32 [3 , 3 ], int , None ], None ),
@@ -72,9 +78,51 @@ def test_get_array_type(hint, expected):
7278 assert type_parsing .get_array_type (hint ) == expected
7379
7480
81+ @pytest .mark .parametrize (
82+ 'hint, expected' ,
83+ [
84+ (Ray , _DS (Ray , '...' )),
85+ (Ray ['h w' ], _DS (Ray , 'h w' )),
86+ (Ray [..., 3 ], _DS (Ray , '... 3' )),
87+ ],
88+ )
89+ def test_from_hint (hint , expected ):
90+ assert dca .field_utils .DataclassWithShape .from_hint (hint ) == expected
91+
92+
7593def test_get_array_type_error ():
7694 with pytest .raises (NotImplementedError ):
7795 type_parsing .get_array_type (Union [Ray , f32 [3 , 3 ]])
7896
7997 with pytest .raises (NotImplementedError ):
8098 type_parsing .get_array_type (Union [FloatArray [..., 3 ], f32 [3 , 3 ]])
99+
100+
101+ @pytest .mark .parametrize (
102+ 'hint, expected' ,
103+ [
104+ (
105+ Ray ,
106+ dca .array_dataclass ._ArrayFieldMetadata (
107+ inner_shape_non_static = (),
108+ dtype = Ray ,
109+ ),
110+ ),
111+ (
112+ Ray [..., 3 ],
113+ dca .array_dataclass ._ArrayFieldMetadata (
114+ inner_shape_non_static = (3 ,),
115+ dtype = Ray ,
116+ ),
117+ ),
118+ (
119+ Ray ['*shape 4 _' ],
120+ dca .array_dataclass ._ArrayFieldMetadata (
121+ inner_shape_non_static = (4 , None ),
122+ dtype = Ray ,
123+ ),
124+ ),
125+ ],
126+ )
127+ def test_type_to_field_metadata (hint , expected ):
128+ assert dca .array_dataclass ._type_to_field_metadata (hint ) == expected
0 commit comments