Skip to content

Commit 16cf74a

Browse files
authored
[math] fix brainpy.math.scan (#604)
1 parent 7e8dd81 commit 16cf74a

File tree

3 files changed

+24
-9
lines changed

3 files changed

+24
-9
lines changed

brainpy/_src/math/object_transform/controls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,8 @@ def scan(
940940
):
941941
"""``scan`` control flow with :py:class:`~.Variable`.
942942
943+
Similar to ``jax.lax.scan``.
944+
943945
.. versionadded:: 2.4.7
944946
945947
All returns in body function will be gathered
@@ -999,7 +1001,7 @@ def scan(
9991001
rets = jax.eval_shape(transform, init, operands)
10001002
cache_stack(body_fun, dyn_vars) # cache
10011003
if current_transform_number():
1002-
return rets[1]
1004+
return rets[0][1], rets[1]
10031005
del rets
10041006

10051007
transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)

brainpy/_src/math/object_transform/tests/test_controls.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
# -*- coding: utf-8 -*-
2-
import sys
32
import tempfile
43
import unittest
54
from functools import partial
65

76
import jax
8-
from jax import vmap
9-
107
from absl.testing import parameterized
11-
from jax._src import test_util as jtu
8+
from jax import vmap
129

1310
import brainpy as bp
1411
import brainpy.math as bm
@@ -147,6 +144,25 @@ def f(carray, x):
147144
expected = bm.expand_dims(expected, axis=-1)
148145
self.assertTrue(bm.allclose(outs, expected))
149146

147+
def test2(self):
148+
a = bm.Variable(1)
149+
150+
def f(carray, x):
151+
carray += x
152+
a.value += 1.
153+
return carray, a
154+
155+
@bm.jit
156+
def f_outer(carray, x):
157+
carry, outs = bm.scan(f, carray, x, unroll=2)
158+
return carry, outs
159+
160+
carry, outs = f_outer(bm.zeros(2), bm.arange(10))
161+
self.assertTrue(bm.allclose(carry, 45.))
162+
expected = bm.arange(1, 11).astype(outs.dtype)
163+
expected = bm.expand_dims(expected, axis=-1)
164+
self.assertTrue(bm.allclose(outs, expected))
165+
150166

151167
class TestCond(unittest.TestCase):
152168
def test1(self):
@@ -234,7 +250,6 @@ def F2(x):
234250
self.assertTrue(bm.grad(F2)(9.0) == 18.)
235251
self.assertTrue(bm.grad(F2)(11.0) == 1.)
236252

237-
238253
def test_grad2(self):
239254
def F3(x):
240255
return bm.ifelse(conditions=(x >= 10, x >= 0),
@@ -519,6 +534,3 @@ def body(a):
519534
file.seek(0)
520535
out6 = file.read().strip()
521536
self.assertTrue(out5 == out6)
522-
523-
524-

docs/apis/brainpy.math.oo_transform.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ Object-oriented Transformations
6060
ifelse
6161
for_loop
6262
while_loop
63+
scan
6364
jit
6465
cls_jit
6566
to_object

0 commit comments

Comments
 (0)