6969
7070from arraycontext .context import ArrayContext
7171from arraycontext .container import (
72- ContainerT , ArrayOrContainerT , is_array_container_type ,
72+ ContainerT , ArrayOrContainerT ,
7373 serialize_container , deserialize_container )
7474
7575
@@ -117,22 +117,32 @@ def _multimap_array_container_impl(
117117 specific container classes. By default, the recursion is stopped when
118118 a non-:class:`ArrayContainer` class is encountered.
119119 """
120+
121+ # {{{ recursive traversal
122+
120123 def rec (* _args : Any ) -> Any :
121124 template_ary = _args [container_indices [0 ]]
122- if (type (template_ary ) is leaf_cls
123- or not is_array_container_type (template_ary .__class__ )):
125+ if type (template_ary ) is leaf_cls :
124126 return f (* _args )
125127
128+ try :
129+ iterable_template = serialize_container (template_ary )
130+ except TypeError :
131+ return f (* _args )
132+ else :
133+ pass
134+
126135 assert all (
127136 type (_args [i ]) is type (template_ary ) for i in container_indices [1 :]
128137 ), f"expected type '{ type (template_ary ).__name__ } '"
129138
130139 result = []
131140 new_args = list (_args )
132141
133- for subarys in zip (* [
134- serialize_container (_args [i ]) for i in container_indices
135- ]):
142+ for subarys in zip (
143+ iterable_template ,
144+ * [serialize_container (_args [i ]) for i in container_indices [1 :]]
145+ ):
136146 key = None
137147 for i , (subkey , subary ) in zip (container_indices , subarys ):
138148 if key is None :
@@ -146,13 +156,36 @@ def rec(*_args: Any) -> Any:
146156
147157 return process_container (template_ary , result ) # type: ignore[operator]
148158
149- container_indices : List [int ] = [
150- i for i , arg in enumerate (args )
151- if is_array_container_type (arg .__class__ ) and type (arg ) is not leaf_cls ]
159+ # }}}
160+
161+ # {{{ find all containers in the argument list
162+
163+ container_indices : List [int ] = []
164+
165+ for i , arg in enumerate (args ):
166+ if type (arg ) is leaf_cls :
167+ continue
168+
169+ try :
170+ # FIXME: this will serialize again once `rec` is called, which is
171+ # not great, but it doesn't seem like there's a good way to avoid it
172+ _ = serialize_container (arg )
173+ except TypeError :
174+ pass
175+ else :
176+ container_indices .append (i )
177+
178+ # }}}
179+
180+ # {{{ #containers == 0 => call `f`
152181
153182 if not container_indices :
154183 return f (* args )
155184
185+ # }}}
186+
187+ # {{{ #containers == 1 => call `map_array_container`
188+
156189 if len (container_indices ) == 1 and reduce_func is None :
157190 # NOTE: if we just have one ArrayContainer in args, passing it through
158191 # _map_array_container_impl should be faster
@@ -167,9 +200,15 @@ def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT:
167200 wrapper , template_ary ,
168201 leaf_cls = leaf_cls , recursive = recursive )
169202
203+ # }}}
204+
205+ # {{{ #containers > 1 => call `rec`
206+
170207 process_container = deserialize_container if reduce_func is None else reduce_func
171208 frec = rec if recursive else f
172209
210+ # }}}
211+
173212 return rec (* args )
174213
175214# }}}
0 commit comments