Skip to content

Commit 47efd4b

Browse files
committed
Allow single variable output in vectorize
Also: * rename the `vectorize` kwarg to `replace` * add a test for multiple outputs
1 parent 30e08e2 commit 47efd4b

File tree

3 files changed

+75
-10
lines changed

3 files changed

+75
-10
lines changed

pytensor/graph/replace.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,26 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
213213
return _vectorize_node(op, node, *batched_inputs)
214214

215215

216+
@overload
217+
def vectorize(
218+
outputs: Variable,
219+
replace: Mapping[Variable, Variable],
220+
) -> Variable:
221+
...
222+
223+
224+
@overload
216225
def vectorize(
217-
outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
226+
outputs: Sequence[Variable],
227+
replace: Mapping[Variable, Variable],
218228
) -> Sequence[Variable]:
229+
...
230+
231+
232+
def vectorize(
233+
outputs: Union[Variable, Sequence[Variable]],
234+
replace: Mapping[Variable, Variable],
235+
) -> Union[Variable, Sequence[Variable]]:
219236
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version.
220237
221238
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
@@ -235,20 +252,44 @@ def vectorize(
235252
236253
# Vectorized graph
237254
new_x = pt.matrix("new_x")
238-
[new_y] = vectorize([y], {x: new_x})
255+
new_y = vectorize(y, replace={x: new_x})
239256
240257
fn = pytensor.function([new_x], new_y)
241258
fn([[0, 1, 2], [2, 1, 0]])
242259
# array([[0.09003057, 0.24472847, 0.66524096],
243260
# [0.66524096, 0.24472847, 0.09003057]])
244261
262+
263+
.. code-block:: python
264+
265+
import pytensor
266+
import pytensor.tensor as pt
267+
268+
from pytensor.graph import vectorize
269+
270+
# Original graph
271+
x = pt.vector("x")
272+
y1 = x[0]
273+
y2 = x[-1]
274+
275+
# Vectorized graph
276+
new_x = pt.matrix("new_x")
277+
[new_y1, new_y2] = vectorize([y1, y2], replace={x: new_x})
278+
279+
fn = pytensor.function([new_x], [new_y1, new_y2])
280+
fn([[-10, 0, 10], [-11, 0, 11]])
281+
# [array([-10., -11.]), array([10., 11.])]
282+
245283
"""
246-
# Avoid circular import
284+
if isinstance(outputs, Sequence):
285+
seq_outputs = outputs
286+
else:
287+
seq_outputs = [outputs]
247288

248-
inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
249-
new_inputs = [vectorize.get(inp, inp) for inp in inputs]
289+
inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
290+
new_inputs = [replace.get(inp, inp) for inp in inputs]
250291

251-
def transform(var):
292+
def transform(var: Variable) -> Variable:
252293
if var in inputs:
253294
return new_inputs[inputs.index(var)]
254295

@@ -257,7 +298,13 @@ def transform(var):
257298
batched_node = vectorize_node(node, *batched_inputs)
258299
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
259300

260-
return batched_var
301+
return cast(Variable, batched_var)
261302

262303
# TODO: MergeOptimization or node caching?
263-
return [transform(out) for out in outputs]
304+
seq_vect_outputs = [transform(out) for out in seq_outputs]
305+
306+
if isinstance(outputs, Sequence):
307+
return seq_vect_outputs
308+
else:
309+
[vect_output] = seq_vect_outputs
310+
return vect_output

pytensor/tensor/blockwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def as_core(t, core_t):
275275

276276
igrads = vectorize(
277277
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
278-
vectorize=dict(
278+
replace=dict(
279279
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
280280
),
281281
)

tests/graph/test_replace.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytensor.tensor as pt
66
from pytensor import config, function, shared
7-
from pytensor.graph.basic import graph_inputs
7+
from pytensor.graph.basic import equal_computations, graph_inputs
88
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
99
from pytensor.tensor import dvector, fvector, vector
1010
from tests import unittest_tools as utt
@@ -236,9 +236,27 @@ def test_basic(self):
236236
new_x = pt.matrix("new_x")
237237
[new_y] = vectorize([y], {x: new_x})
238238

239+
# Check we can pass both a sequence or a single variable
240+
alt_new_y = vectorize(y, {x: new_x})
241+
assert equal_computations([new_y], [alt_new_y])
242+
239243
fn = function([new_x], new_y)
240244
test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
241245
np.testing.assert_allclose(
242246
fn(test_new_y),
243247
scipy.special.softmax(test_new_y, axis=-1),
244248
)
249+
250+
def test_multiple_outputs(self):
251+
x = pt.vector("x")
252+
y1 = x[0]
253+
y2 = x[-1]
254+
255+
new_x = pt.matrix("new_x")
256+
[new_y1, new_y2] = vectorize([y1, y2], {x: new_x})
257+
258+
fn = function([new_x], [new_y1, new_y2])
259+
new_x_test = np.arange(9).reshape(3, 3).astype(config.floatX)
260+
new_y1_res, new_y2_res = fn(new_x_test)
261+
np.testing.assert_allclose(new_y1_res, [0, 3, 6])
262+
np.testing.assert_allclose(new_y2_res, [2, 5, 8])

0 commit comments

Comments
 (0)