Skip to content

Commit 5934e09

Browse files
committed
Add for loop based scalar loop
1 parent 23427a0 commit 5934e09

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import torch
22

33
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4-
from pytensor.scalar.basic import (
5-
ScalarOp,
6-
)
4+
from pytensor.scalar.basic import ScalarOp
5+
from pytensor.scalar.loop import ScalarLoop
76

87

98
@pytorch_funcify.register(ScalarOp)
@@ -38,3 +37,24 @@ def pytorch_func(*args):
3837
)
3938

4039
return pytorch_func
40+
41+
42+
@pytorch_funcify.register(ScalarLoop)
43+
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
44+
update = pytorch_funcify(op.fgraph)
45+
46+
def inner(steps, start, constant, update=update, is_while=op.is_while):
47+
# easiest way to do it is to loop
48+
c = start
49+
for i in range(steps):
50+
outs = update(c, constant)
51+
if is_while:
52+
n, done = outs
53+
if done:
54+
return n
55+
c = n
56+
else:
57+
c = outs[0]
58+
return c
59+
60+
return inner

tests/link/pytorch/test_basic.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.op import Op
1515
from pytensor.raise_op import CheckAndRaise
16+
from pytensor.scalar import float64, int64
17+
from pytensor.scalar.loop import ScalarLoop
1618
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
1719
from pytensor.tensor.type import matrix, scalar, vector
1820

@@ -301,3 +303,18 @@ def test_pytorch_MakeVector():
301303
x_fg = FunctionGraph([], [x])
302304

303305
compare_pytorch_and_py(x_fg, [])
306+
307+
308+
def test_ScalarLoop():
309+
n_steps = int64("n_steps")
310+
x0 = float64("x0")
311+
const = float64("const")
312+
x = x0 + const
313+
314+
op = ScalarLoop(init=[x0], constant=[const], update=[x])
315+
x = op(n_steps, x0, const)
316+
317+
fn = function([n_steps, x0, const], x, mode=pytorch_mode)
318+
np.testing.assert_allclose(fn(5, 0, 1), 5)
319+
np.testing.assert_allclose(fn(5, 0, 2), 10)
320+
np.testing.assert_allclose(fn(4, 3, -1), -1)

0 commit comments

Comments
 (0)