Skip to content

Commit 5fdc130

Browse files
ferrinericardoV94
authored andcommitted
add variable_depends_on
1 parent 8ad3317 commit 5fdc130

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

pytensor/graph/basic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,28 @@ def apply_depends_on(apply: Apply, depends_on: Union[Apply, Collection[Apply]])
16031603
return False
16041604

16051605

1606+
def variable_depends_on(
1607+
variable: Variable, depends_on: Union[Variable, Collection[Variable]]
1608+
) -> bool:
1609+
"""Determine if any `depends_on` is in the graph given by ``variable``.
1610+
Parameters
1611+
----------
1612+
variable: Variable
1613+
Node to check
1614+
depends_on: Collection[Variable]
1615+
Nodes to check dependency on
1616+
1617+
Returns
1618+
-------
1619+
bool
1620+
"""
1621+
if not isinstance(depends_on, Collection):
1622+
depends_on = {depends_on}
1623+
else:
1624+
depends_on = set(depends_on)
1625+
return any(interim in depends_on for interim in ancestors([variable]))
1626+
1627+
16061628
def equal_computations(
16071629
xs: List[Union[np.ndarray, Variable]],
16081630
ys: List[Union[np.ndarray, Variable]],

tests/graph/test_basic.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
io_toposort,
2424
list_of_nodes,
2525
orphans_between,
26+
variable_depends_on,
2627
vars_between,
2728
walk,
2829
)
@@ -675,3 +676,22 @@ def test_NominalVariable_create_variable_type():
675676
assert type(ntv_unpkld) is type(ntv)
676677
assert ntv_unpkld.equals(ntv)
677678
assert ntv_unpkld is ntv
679+
680+
681+
def test_variable_depends_on():
682+
x = MyVariable(1)
683+
x.name = "x"
684+
y = MyVariable(1)
685+
y.name = "y"
686+
x2 = MyOp(x)
687+
x2.name = "x2"
688+
y2 = MyOp(y)
689+
y2.name = "y2"
690+
o = MyOp(x2, y)
691+
assert variable_depends_on(o, x)
692+
assert variable_depends_on(o, [x])
693+
assert not variable_depends_on(o, [y2])
694+
assert variable_depends_on(o, [y2, x])
695+
assert not variable_depends_on(y, [y2])
696+
assert variable_depends_on(y, [y])
697+

0 commit comments

Comments
 (0)