Skip to content

Commit 2778160

Browse files
committed
Rename replace/vectorize to replace/vectorize_graph
1 parent 902eeb6 commit 2778160

File tree

4 files changed

+17
-11
lines changed

4 files changed

+17
-11
lines changed

pytensor/graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
clone,
1010
ancestors,
1111
)
12-
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
12+
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
1313
from pytensor.graph.op import Op
1414
from pytensor.graph.type import Type
1515
from pytensor.graph.fg import FunctionGraph

pytensor/graph/replace.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from collections.abc import Iterable, Mapping, Sequence
23
from functools import partial, singledispatch
34
from typing import Optional, Union, cast, overload
@@ -215,22 +216,22 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
215216

216217

217218
@overload
218-
def vectorize(
219+
def vectorize_graph(
219220
outputs: Variable,
220221
replace: Mapping[Variable, Variable],
221222
) -> Variable:
222223
...
223224

224225

225226
@overload
226-
def vectorize(
227+
def vectorize_graph(
227228
outputs: Sequence[Variable],
228229
replace: Mapping[Variable, Variable],
229230
) -> Sequence[Variable]:
230231
...
231232

232233

233-
def vectorize(
234+
def vectorize_graph(
234235
outputs: Union[Variable, Sequence[Variable]],
235236
replace: Mapping[Variable, Variable],
236237
) -> Union[Variable, Sequence[Variable]]:
@@ -309,3 +310,8 @@ def transform(var: Variable) -> Variable:
309310
else:
310311
[vect_output] = seq_vect_outputs
311312
return vect_output
313+
314+
315+
def vectorize(*args, **kwargs):
316+
warnings.warn("vectorize was renamed to vectorize_graph", UserWarning)
317+
return vectorize_node(*args, **kwargs)

pytensor/tensor/blockwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.graph.basic import Apply, Constant, Variable
1010
from pytensor.graph.null_type import NullType
1111
from pytensor.graph.op import Op
12-
from pytensor.graph.replace import _vectorize_node, vectorize
12+
from pytensor.graph.replace import _vectorize_node, vectorize_graph
1313
from pytensor.tensor import as_tensor_variable
1414
from pytensor.tensor.shape import shape_padleft
1515
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -274,7 +274,7 @@ def as_core(t, core_t):
274274

275275
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
276276

277-
igrads = vectorize(
277+
igrads = vectorize_graph(
278278
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
279279
replace=dict(
280280
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)

tests/graph/test_replace.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytensor.tensor as pt
66
from pytensor import config, function, shared
77
from pytensor.graph.basic import equal_computations, graph_inputs
8-
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
8+
from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph
99
from pytensor.tensor import dvector, fvector, vector
1010
from tests import unittest_tools as utt
1111
from tests.graph.utils import MyOp, MyVariable
@@ -226,18 +226,18 @@ def test_graph_replace_disconnected(self):
226226
oc = graph_replace([o], {fake: x.clone()}, strict=True)
227227

228228

229-
class TestVectorize:
229+
class TestVectorizeGraph:
230230
# TODO: Add tests with multiple outputs, constants, and other singleton types
231231

232232
def test_basic(self):
233233
x = pt.vector("x")
234234
y = pt.exp(x) / pt.sum(pt.exp(x))
235235

236236
new_x = pt.matrix("new_x")
237-
[new_y] = vectorize([y], {x: new_x})
237+
[new_y] = vectorize_graph([y], {x: new_x})
238238

239239
# Check we can pass both a sequence or a single variable
240-
alt_new_y = vectorize(y, {x: new_x})
240+
alt_new_y = vectorize_graph(y, {x: new_x})
241241
assert equal_computations([new_y], [alt_new_y])
242242

243243
fn = function([new_x], new_y)
@@ -253,7 +253,7 @@ def test_multiple_outputs(self):
253253
y2 = x[-1]
254254

255255
new_x = pt.matrix("new_x")
256-
[new_y1, new_y2] = vectorize([y1, y2], {x: new_x})
256+
[new_y1, new_y2] = vectorize_graph([y1, y2], {x: new_x})
257257

258258
fn = function([new_x], [new_y1, new_y2])
259259
new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)

0 commit comments

Comments
 (0)