@@ -213,9 +213,26 @@ def vectorize_node(node: Apply, *batched_inputs) -> Apply:
213
213
return _vectorize_node (op , node , * batched_inputs )
214
214
215
215
216
+ @overload
217
+ def vectorize (
218
+ outputs : Variable ,
219
+ replace : Mapping [Variable , Variable ],
220
+ ) -> Variable :
221
+ ...
222
+
223
+
224
+ @overload
216
225
def vectorize (
217
- outputs : Sequence [Variable ], vectorize : Mapping [Variable , Variable ]
226
+ outputs : Sequence [Variable ],
227
+ replace : Mapping [Variable , Variable ],
218
228
) -> Sequence [Variable ]:
229
+ ...
230
+
231
+
232
+ def vectorize (
233
+ outputs : Union [Variable , Sequence [Variable ]],
234
+ replace : Mapping [Variable , Variable ],
235
+ ) -> Union [Variable , Sequence [Variable ]]:
219
236
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version.
220
237
221
238
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
@@ -235,20 +252,44 @@ def vectorize(
235
252
236
253
# Vectorized graph
237
254
new_x = pt.matrix("new_x")
238
- [ new_y] = vectorize([y], {x: new_x})
255
+ new_y = vectorize(y, replace= {x: new_x})
239
256
240
257
fn = pytensor.function([new_x], new_y)
241
258
fn([[0, 1, 2], [2, 1, 0]])
242
259
# array([[0.09003057, 0.24472847, 0.66524096],
243
260
# [0.66524096, 0.24472847, 0.09003057]])
244
261
262
+
263
+ .. code-block:: python
264
+
265
+ import pytensor
266
+ import pytensor.tensor as pt
267
+
268
+ from pytensor.graph import vectorize
269
+
270
+ # Original graph
271
+ x = pt.vector("x")
272
+ y1 = x[0]
273
+ y2 = x[-1]
274
+
275
+ # Vectorized graph
276
+ new_x = pt.matrix("new_x")
277
+ [new_y1, new_y2] = vectorize([y1, y2], replace={x: new_x})
278
+
279
+ fn = pytensor.function([new_x], [new_y1, new_y2])
280
+ fn([[-10, 0, 10], [-11, 0, 11]])
281
+ # [array([-10., -11.]), array([10., 11.])]
282
+
245
283
"""
246
- # Avoid circular import
284
+ if isinstance (outputs , Sequence ):
285
+ seq_outputs = outputs
286
+ else :
287
+ seq_outputs = [outputs ]
247
288
248
- inputs = truncated_graph_inputs (outputs , ancestors_to_include = vectorize .keys ())
249
- new_inputs = [vectorize .get (inp , inp ) for inp in inputs ]
289
+ inputs = truncated_graph_inputs (seq_outputs , ancestors_to_include = replace .keys ())
290
+ new_inputs = [replace .get (inp , inp ) for inp in inputs ]
250
291
251
- def transform (var ) :
292
+ def transform (var : Variable ) -> Variable :
252
293
if var in inputs :
253
294
return new_inputs [inputs .index (var )]
254
295
@@ -257,7 +298,13 @@ def transform(var):
257
298
batched_node = vectorize_node (node , * batched_inputs )
258
299
batched_var = batched_node .outputs [var .owner .outputs .index (var )]
259
300
260
- return batched_var
301
+ return cast ( Variable , batched_var )
261
302
262
303
# TODO: MergeOptimization or node caching?
263
- return [transform (out ) for out in outputs ]
304
+ seq_vect_outputs = [transform (out ) for out in seq_outputs ]
305
+
306
+ if isinstance (outputs , Sequence ):
307
+ return seq_vect_outputs
308
+ else :
309
+ [vect_output ] = seq_vect_outputs
310
+ return vect_output
0 commit comments