|
87 | 87 | ShapeType, |
88 | 88 | _dtype_any, |
89 | 89 | array_dataclass, |
| 90 | + make_dict_of_named_arrays, |
90 | 91 | ) |
91 | 92 |
|
92 | 93 |
|
@@ -172,26 +173,28 @@ def replace_if_different(self, **kwargs: Any) -> Self: |
172 | 173 |
|
173 | 174 | @cached_property |
174 | 175 | def _placeholders(self) -> Mapping[str, Placeholder]: |
175 | | - from pytato.transform import InputGatherer |
176 | | - |
177 | | - mapper = InputGatherer() |
178 | | - |
179 | | - all_placeholders: frozenset[Placeholder] = frozenset() |
180 | | - for ary in self.returns.values(): |
181 | | - new_placeholders = frozenset({ |
182 | | - arg for arg in mapper(ary) |
183 | | - if isinstance(arg, Placeholder)}) |
184 | | - all_placeholders |= new_placeholders |
185 | | - |
186 | | - # FIXME: Need a way to check for *any* captured arrays, not just placeholders |
187 | | - if __debug__: |
188 | | - pl_names = frozenset(arg.name for arg in all_placeholders) |
189 | | - extra_pl_names = pl_names - self.parameters |
190 | | - assert not extra_pl_names, \ |
191 | | - f"Found non-argument placeholder '{next(iter(extra_pl_names))}' " \ |
192 | | - "in function definition." |
193 | | - |
194 | | - return constantdict({arg.name: arg for arg in all_placeholders}) |
| 176 | + from pytato.transform import ListOfInputsGatherer |
| 177 | + |
| 178 | + mapper = ListOfInputsGatherer() |
| 179 | + |
| 180 | + list_of_placeholders: list[Placeholder] = [ |
| 181 | + inp for inp in mapper(make_dict_of_named_arrays(self.returns)) |
| 182 | + if isinstance(inp, Placeholder)] |
| 183 | + |
| 184 | + placeholders: set[Placeholder] = set() |
| 185 | + for pl in list_of_placeholders: |
| 186 | + if pl.name not in self.parameters: |
| 187 | + # FIXME: Need a way to check for *any* captured arrays, not just |
| 188 | + # placeholders |
| 189 | + raise ValueError( |
| 190 | + f"Found non-argument placeholder '{pl}' in function definition.") |
| 191 | + if pl in placeholders: |
| 192 | + raise ValueError( |
| 193 | + f"Duplicated placeholder for argument '{pl.name}' in " |
| 194 | + "function definition.") |
| 195 | + placeholders.add(pl) |
| 196 | + |
| 197 | + return constantdict({pl.name: pl for pl in placeholders}) |
195 | 198 |
|
196 | 199 | def get_placeholder(self, name: str) -> Placeholder: |
197 | 200 | """ |
|
0 commit comments