Skip to content

Commit ec88252

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add for loop based scalar loop
1 parent ae66e82 commit ec88252

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Cast,
66
ScalarOp,
77
)
8+
from pytensor.scalar.loop import ScalarLoop
89

910

1011
@pytorch_funcify.register(ScalarOp)
@@ -49,3 +50,23 @@ def cast(x):
4950
return x.to(dtype=dtype)
5051

5152
return cast
53+
54+
@pytorch_funcify.register(ScalarLoop)
55+
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
56+
update = pytorch_funcify(op.fgraph)
57+
58+
def inner(steps, start, constant, update=update, is_while=op.is_while):
59+
# easiest way to do it is to loop
60+
c = start
61+
for i in range(steps):
62+
outs = update(c, constant)
63+
if is_while:
64+
n, done = outs
65+
if done:
66+
return n
67+
c = n
68+
else:
69+
c = outs[0]
70+
return c
71+
72+
return inner

tests/link/pytorch/test_basic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from pytensor.ifelse import ifelse
1818
from pytensor.link.pytorch.linker import PytorchLinker
1919
from pytensor.raise_op import CheckAndRaise
20+
from pytensor.scalar import float64, int64
21+
from pytensor.scalar.loop import ScalarLoop
2022
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
2123
from pytensor.tensor.type import matrices, matrix, scalar, vector
2224

@@ -374,3 +376,18 @@ def inner_fn(x):
374376
f = function([x], out, mode="PYTORCH")
375377
f(torch.ones(3))
376378
assert "inner_fn" not in dir(m), "function call reference leaked"
379+
380+
381+
def test_ScalarLoop():
382+
n_steps = int64("n_steps")
383+
x0 = float64("x0")
384+
const = float64("const")
385+
x = x0 + const
386+
387+
op = ScalarLoop(init=[x0], constant=[const], update=[x])
388+
x = op(n_steps, x0, const)
389+
390+
fn = function([n_steps, x0, const], x, mode=pytorch_mode)
391+
np.testing.assert_allclose(fn(5, 0, 1), 5)
392+
np.testing.assert_allclose(fn(5, 0, 2), 10)
393+
np.testing.assert_allclose(fn(4, 3, -1), -1)

0 commit comments

Comments
 (0)