Commit 7804b90

Ch0ronomato (Ian Schweer) authored and committed

Add for loop based scalar loop

1 parent 0824dba · commit 7804b90

File tree

2 files changed: 37 additions, 1 deletion

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 21 additions & 1 deletion
@@ -8,6 +8,7 @@
     ScalarOp,
 )
 from pytensor.scalar.math import Softplus
+from pytensor.scalar.loop import ScalarLoop


 @pytorch_funcify.register(ScalarOp)
@@ -58,7 +59,26 @@ def cast(x):

     return cast

-
 @pytorch_funcify.register(Softplus)
 def pytorch_funcify_Softplus(op, node, **kwargs):
     return torch.nn.Softplus()
+
+@pytorch_funcify.register(ScalarLoop)
+def pytorch_funicify_ScalarLoop(op, node, **kwargs):
+    update = pytorch_funcify(op.fgraph)
+
+    def inner(steps, start, constant, update=update, is_while=op.is_while):
+        # easiest way to do it is to loop
+        c = start
+        for i in range(steps):
+            outs = update(c, constant)
+            if is_while:
+                n, done = outs
+                if done:
+                    return n
+                c = n
+            else:
+                c = outs[0]
+        return c
+
+    return inner
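
For readers skimming the diff: the dispatch above compiles the loop's inner fgraph once via pytorch_funcify and then re-applies it in a plain Python for loop, carrying the state forward each step and, for while-style loops, exiting early once the termination flag is set. Below is a minimal standalone sketch of that control flow; step_update and run_loop are illustrative names invented for this note, not part of the commit, and step_update merely stands in for the compiled fgraph.

# Illustrative sketch only: mimics the control flow of `inner` above.
def step_update(c, constant):
    # Stand-in for the compiled fgraph: one iteration's update.
    return (c + constant,)

def run_loop(steps, start, constant, update=step_update, is_while=False):
    c = start
    for _ in range(steps):
        outs = update(c, constant)
        if is_while:
            n, done = outs
            if done:  # early exit for while-style loops
                return n
            c = n
        else:
            c = outs[0]
    return c

print(run_loop(5, 0.0, 1.0))  # 5.0 == start + steps * constant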

tests/link/pytorch/test_basic.py

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,8 @@
 from pytensor.ifelse import ifelse
 from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
+from pytensor.scalar import float64, int64
+from pytensor.scalar.loop import ScalarLoop
 from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
 from pytensor.tensor.type import matrices, matrix, scalar, vector

@@ -388,3 +390,17 @@ def test_pytorch_softplus():
     out = softplus(x)
     f = FunctionGraph([x], [out])
     compare_pytorch_and_py(f, [np.random.rand(3)])
+
+def test_ScalarLoop():
+    n_steps = int64("n_steps")
+    x0 = float64("x0")
+    const = float64("const")
+    x = x0 + const
+
+    op = ScalarLoop(init=[x0], constant=[const], update=[x])
+    x = op(n_steps, x0, const)
+
+    fn = function([n_steps, x0, const], x, mode=pytorch_mode)
+    np.testing.assert_allclose(fn(5, 0, 1), 5)
+    np.testing.assert_allclose(fn(5, 0, 2), 10)
+    np.testing.assert_allclose(fn(4, 3, -1), -1)
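
The expected values in the test follow from the update x = x0 + const being applied n_steps times, which gives x0 + n_steps * const. A quick plain-Python check of that arithmetic is sketched below; the helper name expected is illustrative only and does not appear in the commit.

def expected(n_steps, x0, const):
    # Closed form of adding `const` to `x0` for `n_steps` iterations.
    return x0 + n_steps * const

assert expected(5, 0, 1) == 5
assert expected(5, 0, 2) == 10
assert expected(4, 3, -1) == -1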
