Skip to content

Commit a39e5b8

Browse files
majosminducer
authored andcommitted
use ListOfInputsGatherer in FunctionDefinition's placeholder detection code
InputGatherer is too strict, as it will error on any duplicated nodes in the function body
1 parent 4340fa0 commit a39e5b8

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

pytato/function.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
ShapeType,
8888
_dtype_any,
8989
array_dataclass,
90+
make_dict_of_named_arrays,
9091
)
9192

9293

@@ -172,26 +173,28 @@ def replace_if_different(self, **kwargs: Any) -> Self:
172173

173174
@cached_property
174175
def _placeholders(self) -> Mapping[str, Placeholder]:
175-
from pytato.transform import InputGatherer
176-
177-
mapper = InputGatherer()
178-
179-
all_placeholders: frozenset[Placeholder] = frozenset()
180-
for ary in self.returns.values():
181-
new_placeholders = frozenset({
182-
arg for arg in mapper(ary)
183-
if isinstance(arg, Placeholder)})
184-
all_placeholders |= new_placeholders
185-
186-
# FIXME: Need a way to check for *any* captured arrays, not just placeholders
187-
if __debug__:
188-
pl_names = frozenset(arg.name for arg in all_placeholders)
189-
extra_pl_names = pl_names - self.parameters
190-
assert not extra_pl_names, \
191-
f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \
192-
"in function definition."
193-
194-
return constantdict({arg.name: arg for arg in all_placeholders})
176+
from pytato.transform import ListOfInputsGatherer
177+
178+
mapper = ListOfInputsGatherer()
179+
180+
list_of_placeholders: list[Placeholder] = [
181+
inp for inp in mapper(make_dict_of_named_arrays(self.returns))
182+
if isinstance(inp, Placeholder)]
183+
184+
placeholders: set[Placeholder] = set()
185+
for pl in list_of_placeholders:
186+
if pl.name not in self.parameters:
187+
# FIXME: Need a way to check for *any* captured arrays, not just
188+
# placeholders
189+
raise ValueError(
190+
f"Found non-argument placeholder '{pl}' in function definition.")
191+
if pl in placeholders:
192+
raise ValueError(
193+
f"Duplicated placeholder for argument '{pl.name}' in "
194+
"function definition.")
195+
placeholders.add(pl)
196+
197+
return constantdict({pl.name: pl for pl in placeholders})
195198

196199
def get_placeholder(self, name: str) -> Placeholder:
197200
"""

0 commit comments

Comments
 (0)