Skip to content

Commit 13e9a77

Browse files
alexfiklinducer
authored andcommitted
remove is_array_container_type in multimap
1 parent d03855d commit 13e9a77

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

arraycontext/container/traversal.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969

7070
from arraycontext.context import ArrayContext
7171
from arraycontext.container import (
72-
ContainerT, ArrayOrContainerT, is_array_container_type,
72+
ContainerT, ArrayOrContainerT,
7373
serialize_container, deserialize_container)
7474

7575

@@ -117,22 +117,32 @@ def _multimap_array_container_impl(
117117
specific container classes. By default, the recursion is stopped when
118118
a non-:class:`ArrayContainer` class is encountered.
119119
"""
120+
121+
# {{{ recursive traversal
122+
120123
def rec(*_args: Any) -> Any:
121124
template_ary = _args[container_indices[0]]
122-
if (type(template_ary) is leaf_cls
123-
or not is_array_container_type(template_ary.__class__)):
125+
if type(template_ary) is leaf_cls:
124126
return f(*_args)
125127

128+
try:
129+
iterable_template = serialize_container(template_ary)
130+
except TypeError:
131+
return f(*_args)
132+
else:
133+
pass
134+
126135
assert all(
127136
type(_args[i]) is type(template_ary) for i in container_indices[1:]
128137
), f"expected type '{type(template_ary).__name__}'"
129138

130139
result = []
131140
new_args = list(_args)
132141

133-
for subarys in zip(*[
134-
serialize_container(_args[i]) for i in container_indices
135-
]):
142+
for subarys in zip(
143+
iterable_template,
144+
*[serialize_container(_args[i]) for i in container_indices[1:]]
145+
):
136146
key = None
137147
for i, (subkey, subary) in zip(container_indices, subarys):
138148
if key is None:
@@ -146,13 +156,36 @@ def rec(*_args: Any) -> Any:
146156

147157
return process_container(template_ary, result) # type: ignore[operator]
148158

149-
container_indices: List[int] = [
150-
i for i, arg in enumerate(args)
151-
if is_array_container_type(arg.__class__) and type(arg) is not leaf_cls]
159+
# }}}
160+
161+
# {{{ find all containers in the argument list
162+
163+
container_indices: List[int] = []
164+
165+
for i, arg in enumerate(args):
166+
if type(arg) is leaf_cls:
167+
continue
168+
169+
try:
170+
# FIXME: this will serialize again once `rec` is called, which is
171+
# not great, but it doesn't seem like there's a good way to avoid it
172+
_ = serialize_container(arg)
173+
except TypeError:
174+
pass
175+
else:
176+
container_indices.append(i)
177+
178+
# }}}
179+
180+
# {{{ #containers == 0 => call `f`
152181

153182
if not container_indices:
154183
return f(*args)
155184

185+
# }}}
186+
187+
# {{{ #containers == 1 => call `map_array_container`
188+
156189
if len(container_indices) == 1 and reduce_func is None:
157190
# NOTE: if we just have one ArrayContainer in args, passing it through
158191
# _map_array_container_impl should be faster
@@ -167,9 +200,15 @@ def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT:
167200
wrapper, template_ary,
168201
leaf_cls=leaf_cls, recursive=recursive)
169202

203+
# }}}
204+
205+
# {{{ #containers > 1 => call `rec`
206+
170207
process_container = deserialize_container if reduce_func is None else reduce_func
171208
frec = rec if recursive else f
172209

210+
# }}}
211+
173212
return rec(*args)
174213

175214
# }}}

arraycontext/impl/pytato/compile.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@
2828
THE SOFTWARE.
2929
"""
3030

31-
from arraycontext.container import ArrayContainer
31+
from arraycontext.container import ArrayContainer, is_array_container_type
3232
from arraycontext import PytatoPyOpenCLArrayContext
33-
from arraycontext.container.traversal import (rec_keyed_map_array_container,
34-
is_array_container_type)
33+
from arraycontext.container.traversal import rec_keyed_map_array_container
3534

3635
import numpy as np
3736
from typing import Any, Callable, Tuple, Dict, Mapping
@@ -162,8 +161,8 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
162161
def _rec_to_placeholder(keys, ary):
163162
name = arg_id_to_name[(kw,) + keys]
164163
return pt.make_placeholder(name, ary.shape, ary.dtype)
165-
return rec_keyed_map_array_container(_rec_to_placeholder,
166-
arg)
164+
165+
return rec_keyed_map_array_container(_rec_to_placeholder, arg)
167166
else:
168167
raise NotImplementedError(type(arg))
169168

0 commit comments

Comments
 (0)