@@ -159,7 +159,7 @@ def imag(self):
159159
160160@serialize_container .register (DOFArray )
161161def _serialize_dof_container (ary : DOFArray ):
162- return enumerate (ary .data )
162+ return list ( enumerate (ary .data ) )
163163
164164
165165@deserialize_container .register (DOFArray )
@@ -203,17 +203,27 @@ def randn(shape, dtype):
203203 rng = np .random .default_rng ()
204204 dtype = np .dtype (dtype )
205205
206+ if shape == 0 :
207+ ashape = 1
208+ else :
209+ ashape = shape
210+
206211 if dtype .kind == "c" :
207212 dtype = np .dtype (f"<f{ dtype .itemsize // 2 } " )
208- return rng .standard_normal (shape , dtype ) \
209- + 1j * rng .standard_normal (shape , dtype )
213+ r = rng .standard_normal (ashape , dtype ) \
214+ + 1j * rng .standard_normal (ashape , dtype )
210215 elif dtype .kind == "f" :
211- return rng .standard_normal (shape , dtype )
216+ r = rng .standard_normal (ashape , dtype )
212217 elif dtype .kind == "i" :
213- return rng .integers (0 , 128 , shape , dtype )
218+ r = rng .integers (0 , 512 , ashape , dtype )
214219 else :
215220 raise TypeError (dtype .kind )
216221
222+ if shape == 0 :
223+ return np .array (r [0 ])
224+
225+ return r
226+
217227
218228def assert_close_to_numpy (actx , op , args ):
219229 assert np .allclose (
@@ -672,11 +682,14 @@ def array_context(self):
672682 return self .mass .array_context
673683
674684
675- def _get_test_containers (actx , ambient_dim = 2 , size = 50_000 ):
676- if size == 0 :
677- x = DOFArray (actx , (actx .from_numpy (np .array (np .random .randn ())),))
678- else :
679- x = DOFArray (actx , (actx .from_numpy (np .random .randn (size )),))
685+ def _get_test_containers (actx , ambient_dim = 2 , shapes = 50_000 ):
686+ from numbers import Number
687+ if isinstance (shapes , (Number , tuple )):
688+ shapes = [shapes ]
689+
690+ x = DOFArray (actx , tuple ([
691+ actx .from_numpy (randn (shape , np .float64 ))
692+ for shape in shapes ]))
680693
681694 # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
682695 dataclass_of_dofs = MyContainer (
@@ -705,7 +718,7 @@ def _get_test_containers(actx, ambient_dim=2, size=50_000):
705718def test_container_scalar_map (actx_factory ):
706719 actx = actx_factory ()
707720
708- arys = _get_test_containers (actx , size = 0 )
721+ arys = _get_test_containers (actx , shapes = 0 )
709722 arys += (np .pi ,)
710723
711724 from arraycontext import (
@@ -877,16 +890,76 @@ def test_container_norm(actx_factory, ord):
877890# }}}
878891
879892
893+ # {{{ test flatten and unflatten
894+
895+ @pytest .mark .parametrize ("shapes" , [
896+ 0 , # tests device scalars when flattening
897+ 512 ,
898+ [(128 , 67 )],
899+ [(127 , 67 ), (18 , 0 )], # tests 0-sized arrays
900+ [(64 , 7 ), (154 , 12 )]
901+ ])
902+ def test_flatten_array_container (actx_factory , shapes ):
903+ if np .prod (shapes ) == 0 :
904+ # https://github.com/inducer/loopy/pull/497
905+ # NOTE: only fails for the pytato array context at the moment
906+ pytest .xfail ("strides do not match in subary" )
907+
908+ actx = actx_factory ()
909+
910+ from arraycontext import flatten , unflatten
911+ arys = _get_test_containers (actx , shapes = shapes )
912+
913+ for ary in arys :
914+ flat = flatten (ary , actx )
915+ assert flat .ndim == 1
916+
917+ ary_roundtrip = unflatten (ary , flat , actx )
918+
919+ from arraycontext import rec_multimap_reduce_array_container
920+ assert rec_multimap_reduce_array_container (
921+ np .prod ,
922+ lambda x , y : x .shape == y .shape ,
923+ ary , ary_roundtrip )
924+
925+ assert actx .to_numpy (
926+ actx .np .linalg .norm (ary - ary_roundtrip )
927+ ) < 1.0e-15
928+
929+
930+ def test_flatten_array_container_failure (actx_factory ):
931+ actx = actx_factory ()
932+
933+ from arraycontext import flatten , unflatten
934+ ary = _get_test_containers (actx , shapes = 512 )[0 ]
935+ flat_ary = flatten (ary , actx )
936+
937+ with pytest .raises (TypeError ):
938+ # cannot unflatten from a numpy array
939+ unflatten (ary , actx .to_numpy (flat_ary ), actx )
940+
941+ with pytest .raises (ValueError ):
942+ # cannot unflatten non-flat arrays
943+ unflatten (ary , flat_ary .reshape (2 , - 1 ), actx )
944+
945+ with pytest .raises (ValueError ):
946+ # cannot unflatten partially
947+ unflatten (ary , flat_ary [:- 1 ], actx )
948+
949+ # }}}
950+
951+
880952# {{{ test from_numpy and to_numpy
881953
882954def test_numpy_conversion (actx_factory ):
883955 actx = actx_factory ()
884956
957+ nelements = 42
885958 ac = MyContainer (
886959 name = "test_numpy_conversion" ,
887- mass = np .random .rand (42 ),
888- momentum = make_obj_array ([np .random .rand (42 ) for _ in range (3 )]),
889- enthalpy = np .random .rand (42 ),
960+ mass = np .random .rand (nelements , nelements ),
961+ momentum = make_obj_array ([np .random .rand (nelements ) for _ in range (3 )]),
962+ enthalpy = np .array ( np . random .rand () ),
890963 )
891964
892965 from arraycontext import from_numpy , to_numpy
0 commit comments