|
20 | 20 | from pytensor.scalar import float64, int64 |
21 | 21 | from pytensor.scalar.loop import ScalarLoop |
22 | 22 | from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus |
| 23 | +from pytensor.tensor.elemwise import Elemwise |
23 | 24 | from pytensor.tensor.type import matrices, matrix, scalar, vector |
24 | 25 |
|
25 | 26 |
|
@@ -420,3 +421,33 @@ def test_ScalarLoop_while(): |
420 | 421 | ): |
421 | 422 | np.testing.assert_allclose(res[0], np.array(expected[0])) |
422 | 423 | np.testing.assert_allclose(res[1], np.array(expected[1])) |
| 424 | + |
| 425 | +def test_pytorch_OpFromGraph(): |
| 426 | + x, y, z = matrices("xyz") |
| 427 | + ofg_1 = OpFromGraph([x, y], [x + y]) |
| 428 | + ofg_2 = OpFromGraph([x, y], [x * y, x - y]) |
| 429 | + |
| 430 | + o1, o2 = ofg_2(y, z) |
| 431 | + out = ofg_1(x, o1) + o2 |
| 432 | + |
| 433 | + xv = np.ones((2, 2), dtype=config.floatX) |
| 434 | + yv = np.ones((2, 2), dtype=config.floatX) * 3 |
| 435 | + zv = np.ones((2, 2), dtype=config.floatX) * 5 |
| 436 | + |
| 437 | + f = FunctionGraph([x, y, z], [out]) |
| 438 | + compare_pytorch_and_py(f, [xv, yv, zv]) |
| 439 | + |
| 440 | + |
| 441 | +def test_ScalarLoop_Elemwise(): |
| 442 | + n_steps = int64("n_steps") |
| 443 | + x0 = float64("x0") |
| 444 | + x = x0 * 2 |
| 445 | + until = x >= 10 |
| 446 | + |
| 447 | + op = ScalarLoop(init=[x0], update=[x], until=until) |
| 448 | + fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode) |
| 449 | + |
| 450 | + states, dones = fn(10, np.array(range(5))) |
| 451 | + |
| 452 | + np.testing.assert_allclose(states, [0, 4, 8, 12, 16]) |
| 453 | + np.testing.assert_allclose(dones, [False, False, False, True, True]) |
0 commit comments