@@ -756,6 +756,59 @@ def test_container_scalar_map(actx_factory):
756756 assert result is not None
757757
758758
759+ def test_container_map (actx_factory ):
760+ actx = actx_factory ()
761+ ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs , bcast_dc_of_dofs = \
762+ _get_test_containers (actx )
763+
764+ # {{{ check
765+
766+ def _check_allclose (f , arg1 , arg2 , atol = 2.0e-14 ):
767+ from arraycontext import NotAnArrayContainerError
768+ try :
769+ arg1_iterable = serialize_container (arg1 )
770+ arg2_iterable = serialize_container (arg2 )
771+ except NotAnArrayContainerError :
772+ assert np .linalg .norm (actx .to_numpy (f (arg1 ) - arg2 )) < atol
773+ else :
774+ arg1_subarrays = [
775+ subarray for _ , subarray in arg1_iterable ]
776+ arg2_subarrays = [
777+ subarray for _ , subarray in arg2_iterable ]
778+ for subarray1 , subarray2 in zip (arg1_subarrays , arg2_subarrays ):
779+ _check_allclose (f , subarray1 , subarray2 )
780+
781+ def func (x ):
782+ return x + 1
783+
784+ from arraycontext import rec_map_array_container
785+ result = rec_map_array_container (func , 1 )
786+ assert result == 2
787+
788+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
789+ result = rec_map_array_container (func , ary )
790+ _check_allclose (func , ary , result )
791+
792+ from arraycontext import mapped_over_array_containers
793+
794+ @mapped_over_array_containers
795+ def mapped_func (x ):
796+ return func (x )
797+
798+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
799+ result = mapped_func (ary )
800+ _check_allclose (func , ary , result )
801+
802+ @mapped_over_array_containers (leaf_class = DOFArray )
803+ def check_leaf (x ):
804+ assert isinstance (x , DOFArray )
805+
806+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
807+ check_leaf (ary )
808+
809+ # }}}
810+
811+
759812def test_container_multimap (actx_factory ):
760813 actx = actx_factory ()
761814 ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs , bcast_dc_of_dofs = \
@@ -764,7 +817,19 @@ def test_container_multimap(actx_factory):
764817 # {{{ check
765818
766819 def _check_allclose (f , arg1 , arg2 , atol = 2.0e-14 ):
767- assert np .linalg .norm (actx .to_numpy (f (arg1 ) - arg2 )) < atol
820+ from arraycontext import NotAnArrayContainerError
821+ try :
822+ arg1_iterable = serialize_container (arg1 )
823+ arg2_iterable = serialize_container (arg2 )
824+ except NotAnArrayContainerError :
825+ assert np .linalg .norm (actx .to_numpy (f (arg1 ) - arg2 )) < atol
826+ else :
827+ arg1_subarrays = [
828+ subarray for _ , subarray in arg1_iterable ]
829+ arg2_subarrays = [
830+ subarray for _ , subarray in arg2_iterable ]
831+ for subarray1 , subarray2 in zip (arg1_subarrays , arg2_subarrays ):
832+ _check_allclose (f , subarray1 , subarray2 )
768833
769834 def func_all_scalar (x , y ):
770835 return x + y
@@ -779,17 +844,30 @@ def func_multiple_scalar(a, subary1, b, subary2):
779844 result = rec_multimap_array_container (func_all_scalar , 1 , 2 )
780845 assert result == 3
781846
782- from functools import partial
783847 for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
784848 result = rec_multimap_array_container (func_first_scalar , 1 , ary )
785- rec_multimap_array_container (
786- partial (_check_allclose , lambda x : 1 + x ),
787- ary , result )
849+ _check_allclose (lambda x : 1 + x , ary , result )
788850
789851 result = rec_multimap_array_container (func_multiple_scalar , 2 , ary , 2 , ary )
790- rec_multimap_array_container (
791- partial (_check_allclose , lambda x : 4 * x ),
792- ary , result )
852+ _check_allclose (lambda x : 4 * x , ary , result )
853+
854+ from arraycontext import multimapped_over_array_containers
855+
856+ @multimapped_over_array_containers
857+ def mapped_func (a , subary1 , b , subary2 ):
858+ return func_multiple_scalar (a , subary1 , b , subary2 )
859+
860+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
861+ result = mapped_func (2 , ary , 2 , ary )
862+ _check_allclose (lambda x : 4 * x , ary , result )
863+
864+ @multimapped_over_array_containers (leaf_class = DOFArray )
865+ def check_leaf (a , subary1 , b , subary2 ):
866+ assert isinstance (subary1 , DOFArray )
867+ assert isinstance (subary2 , DOFArray )
868+
869+ for ary in [ary_dof , ary_of_dofs , mat_of_dofs , dc_of_dofs ]:
870+ check_leaf (2 , ary , 2 , ary )
793871
794872 with pytest .raises (AssertionError ):
795873 rec_multimap_array_container (func_multiple_scalar , 2 , ary_dof , 2 , dc_of_dofs )
0 commit comments