@@ -256,43 +256,51 @@ def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
256256
257257def rec_map_array_container (
258258 f : Callable [[Any ], Any ],
259- ary : ArrayOrContainerT ) -> ArrayOrContainerT :
259+ ary : ArrayOrContainerT ,
260+ leaf_class : Optional [type ] = None ) -> ArrayOrContainerT :
260261 r"""Applies *f* recursively to an :class:`ArrayContainer`.
261262
262263 For a non-recursive version see :func:`map_array_container`.
263264
264265 :param ary: a (potentially nested) structure of :class:`ArrayContainer`\ s,
265266 or an instance of a base array type.
266267 """
267- return _map_array_container_impl (f , ary , recursive = True )
268+ return _map_array_container_impl (f , ary , leaf_cls = leaf_class , recursive = True )
268269
269270
270271def mapped_over_array_containers (
271- f : Callable [[Any ], Any ]) -> Callable [[ArrayOrContainerT ], ArrayOrContainerT ]:
272+ f : Callable [[Any ], Any ],
273+ leaf_class : Optional [type ] = None ) -> Callable [
274+ [ArrayOrContainerT ], ArrayOrContainerT ]:
272275 """Decorator around :func:`rec_map_array_container`."""
273- wrapper = partial (rec_map_array_container , f )
276+ wrapper = partial (rec_map_array_container , f , leaf_class = leaf_class )
274277 update_wrapper (wrapper , f )
275278 return wrapper
276279
277280
278- def rec_multimap_array_container (f : Callable [..., Any ], * args : Any ) -> Any :
281+ def rec_multimap_array_container (
282+ f : Callable [..., Any ],
283+ * args : Any ,
284+ leaf_class : Optional [type ] = None ) -> Any :
279285 r"""Applies *f* recursively to multiple :class:`ArrayContainer`\ s.
280286
281287 For a non-recursive version see :func:`multimap_array_container`.
282288
283289 :param args: all :class:`ArrayContainer` arguments must be of the same
284290 type and with the same structure (same number of components, etc.).
285291 """
286- return _multimap_array_container_impl (f , * args , recursive = True )
292+ return _multimap_array_container_impl (
293+ f , * args , leaf_cls = leaf_class , recursive = True )
287294
288295
289296def multimapped_over_array_containers (
290- f : Callable [..., Any ]) -> Callable [..., Any ]:
297+ f : Callable [..., Any ],
298+ leaf_class : Optional [type ] = None ) -> Callable [..., Any ]:
291299 """Decorator around :func:`rec_multimap_array_container`."""
292300 # can't use functools.partial, because its result is insufficiently
293301 # function-y to be used as a method definition.
294302 def wrapper (* args : Any ) -> Any :
295- return rec_multimap_array_container (f , * args )
303+ return rec_multimap_array_container (f , * args , leaf_class = leaf_class )
296304
297305 update_wrapper (wrapper , f )
298306 return wrapper
@@ -401,7 +409,8 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
401409def rec_map_reduce_array_container (
402410 reduce_func : Callable [[Iterable [Any ]], Any ],
403411 map_func : Callable [[Any ], Any ],
404- ary : ArrayOrContainerT ) -> "DeviceArray" :
412+ ary : ArrayOrContainerT ,
413+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
405414 """Perform a map-reduce over array containers recursively.
406415
407416 :param reduce_func: callable used to reduce over the components of *ary*
@@ -440,22 +449,26 @@ def rec_map_reduce_array_container(
440449 or any other such traversal.
441450 """
442451 def rec (_ary : ArrayOrContainerT ) -> ArrayOrContainerT :
443- try :
444- iterable = serialize_container (_ary )
445- except NotAnArrayContainerError :
452+ if type (_ary ) is leaf_class :
446453 return map_func (_ary )
447454 else :
448- return reduce_func ([
449- rec (subary ) for _ , subary in iterable
450- ])
455+ try :
456+ iterable = serialize_container (_ary )
457+ except NotAnArrayContainerError :
458+ return map_func (_ary )
459+ else :
460+ return reduce_func ([
461+ rec (subary ) for _ , subary in iterable
462+ ])
451463
452464 return rec (ary )
453465
454466
455467def rec_multimap_reduce_array_container (
456468 reduce_func : Callable [[Iterable [Any ]], Any ],
457469 map_func : Callable [..., Any ],
458- * args : Any ) -> "DeviceArray" :
470+ * args : Any ,
471+ leaf_class : Optional [type ] = None ) -> "DeviceArray" :
459472 r"""Perform a map-reduce over multiple array containers recursively.
460473
461474 :param reduce_func: callable used to reduce over the components of any
@@ -478,7 +491,7 @@ def _reduce_wrapper(ary: ContainerT, iterable: Iterable[Tuple[Any, Any]]) -> Any
478491
479492 return _multimap_array_container_impl (
480493 map_func , * args ,
481- reduce_func = _reduce_wrapper , leaf_cls = None , recursive = True )
494+ reduce_func = _reduce_wrapper , leaf_cls = leaf_class , recursive = True )
482495
483496# }}}
484497
0 commit comments