Skip to content

Commit bf72fa6

Browse files
authored
Merge branch 'master' into master
2 parents 2950305 + 9ab9a29 commit bf72fa6

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

brainpy/_src/math/object_transform/controls.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -906,11 +906,10 @@ def _body_fun(op):
906906
new_vals = body_fun(*old_vals)
907907
if new_vals is None:
908908
new_vals = old_vals
909-
else:
910-
if isinstance(new_vals, (list, tuple)):
911-
new_vals = tuple(new_vals)
912-
else:
913-
new_vals = (new_vals,)
909+
if not isinstance(new_vals, tuple):
910+
new_vals = (new_vals,)
911+
if isinstance(new_vals, list):
912+
new_vals = tuple(new_vals)
914913
return dyn_vars.dict_data(), new_vals
915914

916915
def _cond_fun(op):

brainpy/_src/math/object_transform/tests/test_controls.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,26 @@ def body(x, y):
205205
print()
206206
print(res)
207207

208+
def test3(self):
209+
a = bm.Variable(bm.zeros(1))
210+
b = bm.Variable(bm.ones(1))
211+
212+
def cond(x, y):
213+
return bm.all(a.value < 6.)
214+
215+
def body(x, y):
216+
a.value += x
217+
b.value *= y
218+
219+
res = bm.while_loop(body, cond, operands=(1., 1.))
220+
self.assertTrue(bm.allclose(a, 6.))
221+
self.assertTrue(bm.allclose(b, 1.))
222+
print()
223+
print(res)
224+
print(a)
225+
print(b)
226+
227+
208228
def test2(self):
209229
a = bm.Variable(bm.zeros(1))
210230
b = bm.Variable(bm.ones(1))

0 commit comments

Comments
 (0)