@@ -197,6 +197,84 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type:
197197# }}}
198198
199199
200+ # {{{ nested containers
201+
202+ @with_container_arithmetic (bcast_obj_array = False ,
203+ eq_comparison = False , rel_comparison = False ,
204+ _cls_has_array_context_attr = True )
205+ @dataclass_array_container
206+ @dataclass (frozen = True )
207+ class MyContainer :
208+ name : str
209+ mass : Union [DOFArray , np .ndarray ]
210+ momentum : np .ndarray
211+ enthalpy : Union [DOFArray , np .ndarray ]
212+
213+ @property
214+ def array_context (self ):
215+ if isinstance (self .mass , np .ndarray ):
216+ return next (iter (self .mass )).array_context
217+ else :
218+ return self .mass .array_context
219+
220+
221+ @with_container_arithmetic (
222+ bcast_obj_array = False ,
223+ bcast_container_types = (DOFArray , np .ndarray ),
224+ matmul = True ,
225+ rel_comparison = True ,
226+ _cls_has_array_context_attr = True )
227+ @dataclass_array_container
228+ @dataclass (frozen = True )
229+ class MyContainerDOFBcast :
230+ name : str
231+ mass : Union [DOFArray , np .ndarray ]
232+ momentum : np .ndarray
233+ enthalpy : Union [DOFArray , np .ndarray ]
234+
235+ @property
236+ def array_context (self ):
237+ if isinstance (self .mass , np .ndarray ):
238+ return next (iter (self .mass )).array_context
239+ else :
240+ return self .mass .array_context
241+
242+
243+ def _get_test_containers (actx , ambient_dim = 2 , shapes = 50_000 ):
244+ from numbers import Number
245+ if isinstance (shapes , (Number , tuple )):
246+ shapes = [shapes ]
247+
248+ x = DOFArray (actx , tuple ([
249+ actx .from_numpy (randn (shape , np .float64 ))
250+ for shape in shapes ]))
251+
252+ # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
253+ dataclass_of_dofs = MyContainer (
254+ name = "container" ,
255+ mass = x ,
256+ momentum = make_obj_array ([x ] * ambient_dim ),
257+ enthalpy = x )
258+
259+ # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
260+ bcast_dataclass_of_dofs = MyContainerDOFBcast (
261+ name = "container" ,
262+ mass = x ,
263+ momentum = make_obj_array ([x ] * ambient_dim ),
264+ enthalpy = x )
265+
266+ ary_dof = x
267+ ary_of_dofs = make_obj_array ([x ] * ambient_dim )
268+ mat_of_dofs = np .empty ((ambient_dim , ambient_dim ), dtype = object )
269+ for i in np .ndindex (mat_of_dofs .shape ):
270+ mat_of_dofs [i ] = x
271+
272+ return (ary_dof , ary_of_dofs , mat_of_dofs , dataclass_of_dofs ,
273+ bcast_dataclass_of_dofs )
274+
275+ # }}}
276+
277+
200278# {{{ assert_close_to_numpy*
201279
202280def randn (shape , dtype ):
@@ -341,6 +419,23 @@ def test_array_context_np_like(actx_factory, sym_name, n_args, dtype):
341419 assert_close_to_numpy (
342420 actx , lambda _np , * _args : getattr (_np , sym_name )(* _args ), args )
343421
422+ for c in (42.0 ,) + _get_test_containers (actx ):
423+ result = getattr (actx .np , sym_name )(c )
424+ result = actx .thaw (actx .freeze (result ))
425+
426+ if sym_name == "zeros_like" :
427+ if np .isscalar (result ):
428+ assert result == 0.0
429+ else :
430+ assert actx .to_numpy (actx .np .all (actx .np .equal (result , 0.0 )))
431+ elif sym_name == "ones_like" :
432+ if np .isscalar (result ):
433+ assert result == 1.0
434+ else :
435+ assert actx .to_numpy (actx .np .all (actx .np .equal (result , 1.0 )))
436+ else :
437+ raise ValueError (f"unknown method: '{ sym_name } '" )
438+
344439# }}}
345440
346441
@@ -671,79 +766,6 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec):
671766
672767# {{{ array container classes for test
673768
674- @with_container_arithmetic (bcast_obj_array = False ,
675- eq_comparison = False , rel_comparison = False ,
676- _cls_has_array_context_attr = True )
677- @dataclass_array_container
678- @dataclass (frozen = True )
679- class MyContainer :
680- name : str
681- mass : Union [DOFArray , np .ndarray ]
682- momentum : np .ndarray
683- enthalpy : Union [DOFArray , np .ndarray ]
684-
685- @property
686- def array_context (self ):
687- if isinstance (self .mass , np .ndarray ):
688- return next (iter (self .mass )).array_context
689- else :
690- return self .mass .array_context
691-
692-
693- @with_container_arithmetic (
694- bcast_obj_array = False ,
695- bcast_container_types = (DOFArray , np .ndarray ),
696- matmul = True ,
697- rel_comparison = True ,
698- _cls_has_array_context_attr = True )
699- @dataclass_array_container
700- @dataclass (frozen = True )
701- class MyContainerDOFBcast :
702- name : str
703- mass : Union [DOFArray , np .ndarray ]
704- momentum : np .ndarray
705- enthalpy : Union [DOFArray , np .ndarray ]
706-
707- @property
708- def array_context (self ):
709- if isinstance (self .mass , np .ndarray ):
710- return next (iter (self .mass )).array_context
711- else :
712- return self .mass .array_context
713-
714-
715- def _get_test_containers (actx , ambient_dim = 2 , shapes = 50_000 ):
716- from numbers import Number
717- if isinstance (shapes , (Number , tuple )):
718- shapes = [shapes ]
719-
720- x = DOFArray (actx , tuple ([
721- actx .from_numpy (randn (shape , np .float64 ))
722- for shape in shapes ]))
723-
724- # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
725- dataclass_of_dofs = MyContainer (
726- name = "container" ,
727- mass = x ,
728- momentum = make_obj_array ([x ] * ambient_dim ),
729- enthalpy = x )
730-
731- # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
732- bcast_dataclass_of_dofs = MyContainerDOFBcast (
733- name = "container" ,
734- mass = x ,
735- momentum = make_obj_array ([x ] * ambient_dim ),
736- enthalpy = x )
737-
738- ary_dof = x
739- ary_of_dofs = make_obj_array ([x ] * ambient_dim )
740- mat_of_dofs = np .empty ((ambient_dim , ambient_dim ), dtype = object )
741- for i in np .ndindex (mat_of_dofs .shape ):
742- mat_of_dofs [i ] = x
743-
744- return (ary_dof , ary_of_dofs , mat_of_dofs , dataclass_of_dofs ,
745- bcast_dataclass_of_dofs )
746-
747769
748770def test_container_scalar_map (actx_factory ):
749771 actx = actx_factory ()
0 commit comments