|
5 | 5 | import pytensor.tensor as pt
|
6 | 6 | from pytensor import config, function, shared
|
7 | 7 | from pytensor.graph.basic import equal_computations, graph_inputs
|
8 |
| -from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph |
| 8 | +from pytensor.graph.replace import ( |
| 9 | + clone_replace, |
| 10 | + graph_replace, |
| 11 | + vectorize_graph, |
| 12 | + vectorize_node, |
| 13 | +) |
9 | 14 | from pytensor.tensor import dvector, fvector, vector
|
10 | 15 | from tests import unittest_tools as utt
|
11 |
| -from tests.graph.utils import MyOp, MyVariable |
| 16 | +from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs |
12 | 17 |
|
13 | 18 |
|
14 | 19 | class TestCloneReplace:
|
@@ -227,8 +232,6 @@ def test_graph_replace_disconnected(self):
|
227 | 232 |
|
228 | 233 |
|
229 | 234 | class TestVectorizeGraph:
|
230 |
| - # TODO: Add tests with multiple outputs, constants, and other singleton types |
231 |
| - |
232 | 235 | def test_basic(self):
|
233 | 236 | x = pt.vector("x")
|
234 | 237 | y = pt.exp(x) / pt.sum(pt.exp(x))
|
@@ -260,3 +263,63 @@ def test_multiple_outputs(self):
|
260 | 263 | new_y1_res, new_y2_res = fn(new_x_test)
|
261 | 264 | np.testing.assert_allclose(new_y1_res, [0, 3, 6])
|
262 | 265 | np.testing.assert_allclose(new_y2_res, [2, 5, 8])
|
| 266 | + |
| 267 | + def test_multi_output_node(self): |
| 268 | + x = pt.scalar("x") |
| 269 | + node = op_multiple_outputs.make_node(x) |
| 270 | + y1, y2 = node.outputs |
| 271 | + out = pt.add(y1, y2) |
| 272 | + |
| 273 | + new_x = pt.vector("new_x") |
| 274 | + new_y1 = pt.vector("new_y1") |
| 275 | + new_y2 = pt.vector("new_y2") |
| 276 | + |
| 277 | + # Cases where either x or both of y1 and y2 are given replacements |
| 278 | + new_out = vectorize_graph(out, {x: new_x}) |
| 279 | + expected_new_out = pt.add(*vectorize_node(node, new_x).outputs) |
| 280 | + assert equal_computations([new_out], [expected_new_out]) |
| 281 | + |
| 282 | + new_out = vectorize_graph(out, {y1: new_y1, y2: new_y2}) |
| 283 | + expected_new_out = pt.add(new_y1, new_y2) |
| 284 | + assert equal_computations([new_out], [expected_new_out]) |
| 285 | + |
| 286 | + new_out = vectorize_graph(out, {x: new_x, y1: new_y1, y2: new_y2}) |
| 287 | + expected_new_out = pt.add(new_y1, new_y2) |
| 288 | + assert equal_computations([new_out], [expected_new_out]) |
| 289 | + |
| 290 | + # Special case where x is given a replacement as well as only one of y1 and y2 |
| 291 | + # The graph combines the replaced variable with the other vectorized output |
| 292 | + new_out = vectorize_graph(out, {x: new_x, y1: new_y1}) |
| 293 | + expected_new_out = pt.add(new_y1, vectorize_node(node, new_x).outputs[1]) |
| 294 | + assert equal_computations([new_out], [expected_new_out]) |
| 295 | + |
| 296 | + def test_multi_output_node_random_variable(self): |
| 297 | + """This is a regression test for #569. |
| 298 | +
|
| 299 | + Functionally, it covers the same case as `test_multiple_output_node` |
| 300 | + """ |
| 301 | + |
| 302 | + # RandomVariables have two outputs, a hidden RNG and the visible draws |
| 303 | + beta0 = pt.random.normal(name="beta0") |
| 304 | + beta1 = pt.random.normal(name="beta1") |
| 305 | + |
| 306 | + out1 = beta0 + 1 |
| 307 | + out2 = beta1 * pt.exp(out1) |
| 308 | + |
| 309 | + # We replace the second output of each RandomVariable |
| 310 | + new_beta0 = pt.tensor("new_beta0", shape=(3,)) |
| 311 | + new_beta1 = pt.tensor("new_beta1", shape=(3,)) |
| 312 | + |
| 313 | + new_outs = vectorize_graph( |
| 314 | + [out1, out2], |
| 315 | + replace={ |
| 316 | + beta0: new_beta0, |
| 317 | + beta1: new_beta1, |
| 318 | + }, |
| 319 | + ) |
| 320 | + |
| 321 | + expected_new_outs = [ |
| 322 | + new_beta0 + 1, |
| 323 | + new_beta1 * pt.exp(new_beta0 + 1), |
| 324 | + ] |
| 325 | + assert equal_computations(new_outs, expected_new_outs) |
0 commit comments