|
21 | 21 | from typing import Any, cast |
22 | 22 |
|
23 | 23 | from pytensor import function |
24 | | -from pytensor.graph import Apply |
25 | 24 | from pytensor.graph.basic import ancestors, walk |
26 | 25 | from pytensor.scalar.basic import Cast |
27 | 26 | from pytensor.tensor.elemwise import Elemwise |
28 | | -from pytensor.tensor.random.op import RandomVariable |
29 | 27 | from pytensor.tensor.shape import Shape |
30 | 28 | from pytensor.tensor.variable import TensorVariable |
31 | 29 |
|
@@ -240,42 +238,32 @@ class ModelGraph: |
240 | 238 | def __init__(self, model): |
241 | 239 | self.model = model |
242 | 240 | self._all_var_names = get_default_varnames(self.model.named_vars, include_transformed=False) |
| 241 | + self._all_vars = {model[var_name] for var_name in self._all_var_names} |
243 | 242 | self.var_list = self.model.named_vars.values() |
244 | 243 |
|
245 | 244 | def get_parent_names(self, var: TensorVariable) -> set[VarName]: |
246 | | - if var.owner is None or var.owner.inputs is None: |
| 245 | + if var.owner is None: |
247 | 246 | return set() |
248 | 247 |
|
249 | | - def _filter_non_parameter_inputs(var): |
250 | | - node = var.owner |
251 | | - if isinstance(node.op, Shape): |
252 | | - # Don't show shape-related dependencies |
253 | | - return [] |
254 | | - if isinstance(node.op, RandomVariable): |
255 | | - # Filter out rng and size parameters or RandomVariable nodes |
256 | | - return node.op.dist_params(node) |
257 | | - else: |
258 | | - # Otherwise return all inputs |
259 | | - return node.inputs |
260 | | - |
261 | | - blockers = set(self.model.named_vars) |
| 248 | + named_vars = self._all_vars |
262 | 249 |
|
263 | 250 | def _expand(x): |
264 | | - nonlocal blockers |
265 | | - if x.name in blockers: |
| 251 | + if x in named_vars: |
| 252 | + # Don't go beyond named_vars |
266 | 253 | return [x] |
267 | | - if isinstance(x.owner, Apply): |
268 | | - return reversed(_filter_non_parameter_inputs(x)) |
269 | | - return [] |
270 | | - |
271 | | - parents = set() |
272 | | - for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand): |
273 | | - # Only consider nodes that are in the named model variables. |
274 | | - vname = getattr(x, "name", None) |
275 | | - if isinstance(vname, str) and vname in self._all_var_names: |
276 | | - parents.add(VarName(vname)) |
277 | | - |
278 | | - return parents |
| 254 | + if x.owner is None: |
| 255 | + return [] |
| 256 | + if isinstance(x.owner.op, Shape): |
| 257 | + # Don't propagate shape-related dependencies |
| 258 | + return [] |
| 259 | + # Continue walking the graph through the inputs |
| 260 | + return x.owner.inputs |
| 261 | + |
| 262 | + return { |
| 263 | + cast(VarName, ancestor.name) # type: ignore[union-attr] |
| 264 | + for ancestor in walk(nodes=var.owner.inputs, expand=_expand) |
| 265 | + if ancestor in named_vars |
| 266 | + } |
279 | 267 |
|
280 | 268 | def vars_to_plot(self, var_names: Iterable[VarName] | None = None) -> list[VarName]: |
281 | 269 | if var_names is None: |
|
0 commit comments