Commit 9a5deee

Don't run MergeOptimization in Composite.fgraph

Running the MergeOptimizer inside the `fgraph` property would trigger the rewrite for every Composite/ScalarLoop present in the C-cache.

1 parent d9b494d · commit 9a5deee
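To see what the change amounts to in practice, here is a minimal sketch using the `pytensor.scalar` API (scalar-variable arithmetic and the `float64` import are assumptions modeled on the test suite). Before this commit, `MergeOptimizer` ran inside the `fgraph` property, so it re-ran whenever a Composite or ScalarLoop rebuilt its inner graph — per the commit message, for every such op present in the C-cache. After this commit, merging and validation happen once in `__init__`, and `fgraph` just builds and caches a plain `FunctionGraph`:

```python
from pytensor.scalar.basic import Composite, float64

x = float64("x")
y = float64("y")

comp = Composite([x, y], [x + y])  # _cleanup_graph (merge + validation) runs here
fg_first = comp.fgraph             # builds a plain FunctionGraph; no rewriting
fg_again = comp.fgraph             # returns the cached graph
assert fg_first is fg_again
```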

File tree: 4 files changed (+70, −85 lines)
pytensor/scalar/basic.py — 42 additions, 32 deletions
```diff
@@ -3998,6 +3998,42 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
     def __init__(self, *args, **kwargs):
         self.prepare_node_called = set()
 
+    def _cleanup_graph(self, inputs, outputs):
+        # TODO: We could convert to TensorVariable, optimize graph,
+        # and then convert back to ScalarVariable.
+        # This would introduce rewrites like `log(1 + x) -> log1p`.
+
+        fgraph = FunctionGraph(copy(inputs), copy(outputs))
+
+        # Validate node types
+        for node in fgraph.apply_nodes:
+            if not isinstance(node.op, ScalarOp):
+                raise TypeError(
+                    f"The fgraph of {self.__class__.__name__} must be exclusively "
+                    "composed of scalar operations."
+                )
+
+        # Run MergeOptimization to avoid duplicated nodes
+        MergeOptimizer().rewrite(fgraph)
+
+        inputs, outputs = fgraph.inputs, fgraph.outputs
+
+        # Clone identical outputs that may have been merged
+        # If fgraph.outputs = [out_A, out_B, out_A], then final outputs = [out_A, out_B, clone(out_A)]
+        if len(set(fgraph.outputs)) != len(outputs):
+            old_outputs = outputs
+            outputs = []
+            for old_output in old_outputs:
+                if old_output not in outputs:
+                    outputs.append(old_output)
+                else:
+                    node = old_output.owner
+                    output_idx = node.outputs.index(old_output)
+                    output = node.clone().outputs[output_idx]
+                    outputs.append(output)
+
+        return inputs, outputs
+
     @property
     def fn(self):
         return None
```
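The output re-cloning step at the end of `_cleanup_graph` handles a subtle case: `MergeOptimizer` can leave the same variable listed more than once in `fgraph.outputs`, while downstream code expects each output slot to hold a distinct variable. A hedged sketch of the observable effect (assuming the scalar API behaves as in the tests):

```python
from pytensor.scalar.basic import Composite, float64

x = float64("x")
y = x * x  # any scalar expression

# The same expression passed as two outputs: after merging, _cleanup_graph
# re-clones the second occurrence from its owner Apply node.
comp = Composite([x], [y, y])
out_a, out_b = comp.outputs
assert out_a is not out_b                # distinct variables per output slot
assert out_a.owner.op == out_b.owner.op  # but equivalent computations
```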
```diff
@@ -4187,10 +4223,9 @@ def __init__(self, inputs, outputs, name="Composite"):
             assert res[0] != inputs
             inputs, outputs = res[0], res2[1]
 
-        self.inputs = copy(inputs)
-        self.outputs = copy(outputs)
-        self.inputs_type = tuple([input.type for input in inputs])
-        self.outputs_type = tuple([output.type for output in outputs])
+        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
+        self.inputs_type = tuple([input.type for input in self.inputs])
+        self.outputs_type = tuple([output.type for output in self.outputs])
         self.nin = len(inputs)
         self.nout = len(outputs)
         super().__init__()
```
```diff
@@ -4237,34 +4272,9 @@ def make_new_inplace(self, output_types_preference=None, name=None):
     def fgraph(self):
         if hasattr(self, "_fgraph"):
             return self._fgraph
-
-        # The clone done by FunctionGraph is needed as we don't want
-        # the fgraph to be set to the variable as we need to pickle
-        # them for the cache of c module to work.
+        # fgraph cannot be a property of the base class because it messes up with C caching.
+        # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
         fgraph = FunctionGraph(self.inputs, self.outputs)
-        with config.change_flags(optimizer_verbose=False):
-            MergeOptimizer().rewrite(fgraph)
-        for node in fgraph.apply_nodes:
-            if not isinstance(node.op, ScalarOp):
-                raise TypeError(
-                    "The fgraph to Composite must be exclusively"
-                    " composed of ScalarOp instances."
-                )
-
-        # Clone identical outputs that have been merged
-        if len(set(fgraph.outputs)) != len(self.outputs):
-            old_outputs = fgraph.outputs
-            new_outputs = []
-            for output in old_outputs:
-                if output not in new_outputs:
-                    new_outputs.append(output)
-                else:
-                    node = output.owner
-                    output_idx = node.outputs.index(output)
-                    new_output = node.clone().outputs[output_idx]
-                    new_outputs.append(new_output)
-            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
-
         self._fgraph = fgraph
         return self._fgraph
 
```
```diff
@@ -4389,7 +4399,7 @@ def c_code(self, node, nodename, inames, onames, sub):
         return self.c_code_template % d
 
     def c_code_cache_version_outer(self) -> Tuple[int, ...]:
-        return (3,)
+        return (4,)
 
 
 class Compositef32:
```
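Because modules already compiled into the C-cache were produced under the old merge-on-access behavior, `c_code_cache_version_outer` is bumped from `(3,)` to `(4,)`. An illustrative stand-in (not pytensor's actual cache implementation) for why changing the tuple forces recompilation:

```python
# Hypothetical cache key: the version tuple is part of the key, so bumping
# (3,) -> (4,) means entries compiled under the old behavior never match.
def cache_key(op_signature: str, version: tuple) -> str:
    return f"{op_signature}|v{'.'.join(map(str, version))}"

assert cache_key("Composite(add)", (3,)) != cache_key("Composite(add)", (4,))
```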

pytensor/scalar/loop.py — 21 additions, 48 deletions
```diff
@@ -1,11 +1,9 @@
-from copy import copy
 from itertools import chain
-from typing import Optional, Sequence, Tuple, cast
+from typing import Optional, Sequence, Tuple
 
 from pytensor.compile import rebuild_collect_shared
 from pytensor.graph import Constant, FunctionGraph, Variable, clone
-from pytensor.graph.rewriting.basic import MergeOptimizer
-from pytensor.scalar.basic import ScalarInnerGraphOp, ScalarOp, as_scalar
+from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
 
 
 class ScalarLoop(ScalarInnerGraphOp):
```
```diff
@@ -62,44 +60,38 @@ def __init__(
         if not len(init) == len(update):
             raise ValueError("An update must be given for each init variable")
         if until:
-            inputs, (*outputs, until) = clone([*init, *constant], [*update, until])
-            self.outputs = copy([*outputs, until])
+            inputs, outputs = clone([*init, *constant], [*update, until])
         else:
             inputs, outputs = clone([*init, *constant], update)
-            self.outputs = copy(outputs)
-        self.inputs = copy(inputs)
 
         self.is_while = bool(until)
-        self.inputs_type = tuple(input.type for input in inputs)
-        self.outputs_type = tuple(output.type for output in outputs)
-        if self.is_while:
-            self.outputs_type = self.outputs_type + (cast(Variable, until).type,)
-        self.nin = len(inputs) + 1  # n_steps is not part of the inner graph
-        self.nout = len(outputs) + (1 if self.is_while else 0)
+        self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
+        self._validate_updates(self.inputs, self.outputs)
+
+        self.inputs_type = tuple(input.type for input in self.inputs)
+        self.outputs_type = tuple(output.type for output in self.outputs)
+        self.nin = len(self.inputs) + 1  # n_steps is not part of the inner graph
+        self.nout = len(self.outputs)
         self.name = name
-        self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False))
+
         super().__init__()
 
     def output_types(self, input_types):
         return self.outputs_type
 
-    def _validate_fgraph(self, fgraph: FunctionGraph) -> None:
-        for node in fgraph.apply_nodes:
-            if not isinstance(node.op, ScalarOp):
-                raise TypeError(
-                    "The fgraph of ScalarLoop must be composed exclusively of ScalarOp nodes"
-                )
-
-        init = fgraph.inputs
-        update = fgraph.outputs
-
+    def _validate_updates(
+        self, inputs: Sequence[Variable], outputs: Sequence[Variable]
+    ) -> None:
+        init = inputs
+        update: Sequence[Variable]
         if self.is_while:
-            *update, until = update
+            *update, until = outputs
             if not until.type.dtype == "bool":
                 raise TypeError(
                     f"Until condition must be boolean, got {until}({until.type.dtype})"
                 )
-
+        else:
+            update = outputs
         for i, u in zip(init, update):
             if i.type != u.type:
                 raise TypeError(
```
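A hedged sketch of what `_validate_updates` enforces for a while-style loop (the scalar comparison producing a `bool` dtype is an assumption about the `pytensor.scalar` API; the attribute names come from this diff). The `until` expression is carried as the trailing output, which is why `nout` is now simply `len(self.outputs)`:

```python
from pytensor.scalar.basic import float64
from pytensor.scalar.loop import ScalarLoop

x = float64("x")
update = x + 1.0
until = update > 10.0  # must have dtype "bool", otherwise TypeError

op = ScalarLoop(init=[x], constant=[], update=[update], until=until)
assert op.is_while
assert op.nout == 2  # the carried state plus the trailing until condition
assert op.nin == 2   # n_steps plus the single init input
```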
```diff
@@ -116,28 +108,9 @@ def _validate_fgraph(self, fgraph: FunctionGraph) -> None:
     def fgraph(self):
         if hasattr(self, "_fgraph"):
             return self._fgraph
-
+        # fgraph cannot be a property of the base class because it messes up with C caching.
+        # We also need a `FunctionGraph(clone=True)` (default) according to an old comment
         fgraph = FunctionGraph(self.inputs, self.outputs)
-        # TODO: We could convert to TensorVariable, optimize graph,
-        # and then convert back to ScalarVariable.
-        # This would introduce rewrites like `log(1 + x) -> log1p`.
-        MergeOptimizer().rewrite(fgraph)
-        self._validate_fgraph(fgraph)
-
-        # Clone identical outputs that have been merged
-        if len(set(fgraph.outputs)) != len(self.outputs):
-            old_outputs = fgraph.outputs
-            new_outputs = []
-            for output in old_outputs:
-                if output not in new_outputs:
-                    new_outputs.append(output)
-                else:
-                    node = output.owner
-                    output_idx = node.outputs.index(output)
-                    new_output = node.clone().outputs[output_idx]
-                    new_outputs.append(new_output)
-            fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False)
-
         self._fgraph = fgraph
         return self._fgraph
 
```
tests/scalar/test_basic.py — 5 additions, 4 deletions
```diff
@@ -200,10 +200,11 @@ def test_composite_printing(self):
 
     def test_non_scalar_error(self):
         x = float32("x")
-        comp_op = Composite([x], [(at.zeros((2,)) + x).sum()])
-
-        with pytest.raises(TypeError, match=".*exclusively.*ScalarOp.*"):
-            comp_op.fgraph
+        with pytest.raises(
+            TypeError,
+            match="The fgraph of Composite must be exclusively composed of scalar operations",
+        ):
+            Composite([x], [(at.zeros((2,)) + x).sum()])
 
     def test_multi_out_perform(self):
         from pytensor.graph.basic import Apply
```
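Note the shift this updated test encodes: because validation now runs in `__init__` via `_cleanup_graph`, the `TypeError` is raised when the `Composite` is constructed rather than on first access of `comp_op.fgraph`.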

tests/scalar/test_loop.py — 2 additions, 1 deletion
```diff
@@ -151,7 +151,8 @@ def test_non_scalar_error():
     x = as_scalar(tensor_exp(x0))
 
     with pytest.raises(
-        TypeError, match="must be composed exclusively of ScalarOp nodes"
+        TypeError,
+        match="The fgraph of ScalarLoop must be exclusively composed of scalar operations",
     ):
         ScalarLoop(init=[x0], constant=[], update=[x])
```
