Skip to content

Commit c6dae89

Browse files
committed
Move view_roots to the only file where it is used
1 parent d8beb7e commit c6dae89

File tree

3 files changed

+21
-26
lines changed

3 files changed

+21
-26
lines changed

pytensor/graph/basic.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,25 +1787,6 @@ def describe(r):
17871787
return [describe(output) for output in outputs]
17881788

17891789

1790-
def view_roots(node: Variable) -> list[Variable]:
1791-
"""Return the leaves from a search through consecutive view-maps."""
1792-
owner = node.owner
1793-
if owner is not None:
1794-
try:
1795-
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
1796-
except AttributeError:
1797-
return [node]
1798-
if node in vars_to_views:
1799-
answer = []
1800-
for i in vars_to_views[node]:
1801-
answer += view_roots(owner.inputs[i])
1802-
return answer
1803-
else:
1804-
return [node]
1805-
else:
1806-
return [node]
1807-
1808-
18091790
def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool:
18101791
"""Determine if any `depends_on` is in the graph given by ``apply``.
18111792

pytensor/tensor/blas.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
import numpy as np
8686
from scipy.linalg import get_blas_funcs
8787

88-
from pytensor.graph import vectorize_graph
88+
from pytensor.graph import Variable, vectorize_graph
8989
from pytensor.npy_2_compat import normalize_axis_tuple
9090

9191

@@ -97,7 +97,7 @@
9797

9898
import pytensor.scalar
9999
from pytensor.configdefaults import config
100-
from pytensor.graph.basic import Apply, view_roots
100+
from pytensor.graph.basic import Apply
101101
from pytensor.graph.op import Op
102102
from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
103103
from pytensor.link.c.op import COp
@@ -114,6 +114,25 @@
114114
_logger = logging.getLogger("pytensor.tensor.blas")
115115

116116

117+
def view_roots(node: Variable) -> list[Variable]:
118+
"""Return the leaves from a search through consecutive view-maps."""
119+
owner = node.owner
120+
if owner is not None:
121+
try:
122+
vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()}
123+
except AttributeError:
124+
return [node]
125+
if node in vars_to_views:
126+
answer = []
127+
for i in vars_to_views[node]:
128+
answer += view_roots(owner.inputs[i])
129+
return answer
130+
else:
131+
return [node]
132+
else:
133+
return [node]
134+
135+
117136
def must_initialize_y_gemv():
118137
# Check whether Scipy GEMV could output nan if y in not initialized
119138
from scipy.linalg.blas import get_blas_funcs

tests/graph/test_basic.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -589,11 +589,6 @@ def test_io_connection_pattern():
589589
raise AssertionError()
590590

591591

592-
@pytest.mark.xfail(reason="Not implemented")
593-
def test_view_roots():
594-
raise AssertionError()
595-
596-
597592
def test_get_var_by_name():
598593
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
599594
o1 = MyOp(r1, r2)

0 commit comments

Comments
 (0)