Skip to content

Commit bde7f8a

Browse files
authored
disable_ jit support in brainpy.math.scan (#606)
* [math] support disable jit in `brainpy.math.scan` * [math] support brainpy array in `cond`, `ifelse`, `scan` transformations * fix tests
1 parent 16cf74a commit bde7f8a

File tree

2 files changed

+66
-17
lines changed

2 files changed

+66
-17
lines changed

brainpy/_src/math/object_transform/controls.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
23
import functools
34
import numbers
45
from typing import Union, Sequence, Any, Dict, Callable, Optional
@@ -12,7 +13,7 @@
1213

1314
from brainpy import errors, tools
1415
from brainpy._src.math.interoperability import as_jax
15-
from brainpy._src.math.ndarray import (Array, )
16+
from brainpy._src.math.ndarray import (Array, _as_jax_array_)
1617
from .base import BrainPyObject, ObjectTransform
1718
from .naming import (
1819
get_unique_name,
@@ -421,11 +422,27 @@ def call(pred, x=None):
421422
return ControlObject(call, dyn_vars, repr_fun={'true_fun': true_fun, 'false_fun': false_fun})
422423

423424

425+
@functools.cache
426+
def _warp(f):
427+
@functools.wraps(f)
428+
def new_f(*args, **kwargs):
429+
return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array))
430+
431+
return new_f
432+
433+
434+
def _warp_data(data):
435+
def new_f(*args, **kwargs):
436+
return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array))
437+
438+
return new_f
439+
440+
424441
def _check_f(f):
425442
if callable(f):
426-
return f
443+
return _warp(f)
427444
else:
428-
return (lambda *args, **kwargs: f)
445+
return _warp_data(f)
429446

430447

431448
def _check_sequence(a):
@@ -557,7 +574,7 @@ def _if_else_return2(conditions, branches):
557574
return branches[-1]
558575

559576

560-
def all_equal(iterator):
577+
def _all_equal(iterator):
561578
iterator = iter(iterator)
562579
try:
563580
first = next(iterator)
@@ -671,7 +688,7 @@ def ifelse(
671688
else:
672689
rets = [jax.eval_shape(branch, *operands) for branch in branches]
673690
trees = [jax.tree_util.tree_structure(ret) for ret in rets]
674-
if not all_equal(trees):
691+
if not _all_equal(trees):
675692
msg = 'All returns in branches should have the same tree structure. But we got:\n'
676693
for tree in trees:
677694
msg += f'- {tree}\n'
@@ -914,12 +931,14 @@ def fun2scan(carry, x):
914931
carry, results = body_fun(carry, x)
915932
if progress_bar:
916933
id_tap(lambda *arg: bar.update(), ())
934+
carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array))
917935
return (dyn_vars.dict_data(), carry), results
918936

919937
if remat:
920938
fun2scan = jax.checkpoint(fun2scan)
921939

922940
def call(init, operands):
941+
init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array))
923942
return jax.lax.scan(f=fun2scan,
924943
init=(dyn_vars.dict_data(), init),
925944
xs=operands,
@@ -991,19 +1010,21 @@ def scan(
9911010
bar = tqdm(total=num_total)
9921011

9931012
dyn_vars = get_stack_cache(body_fun)
994-
if dyn_vars is None:
995-
with new_transform('scan'):
996-
with VariableStack() as dyn_vars:
997-
transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
998-
if current_transform_number() > 1:
999-
rets = transform(init, operands)
1000-
else:
1001-
rets = jax.eval_shape(transform, init, operands)
1002-
cache_stack(body_fun, dyn_vars) # cache
1003-
if current_transform_number():
1004-
return rets[0][1], rets[1]
1005-
del rets
1013+
if not jax.config.jax_disable_jit:
1014+
if dyn_vars is None:
1015+
with new_transform('scan'):
1016+
with VariableStack() as dyn_vars:
1017+
transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
1018+
if current_transform_number() > 1:
1019+
rets = transform(init, operands)
1020+
else:
1021+
rets = jax.eval_shape(transform, init, operands)
1022+
cache_stack(body_fun, dyn_vars) # cache
1023+
if current_transform_number():
1024+
return rets[0][1], rets[1]
1025+
del rets
10061026

1027+
dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
10071028
transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)
10081029
(dyn_vals, carry), out_vals = transform(init, operands)
10091030
for key in dyn_vars.keys():

brainpy/_src/math/object_transform/tests/test_controls.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,34 @@ def f_outer(carray, x):
163163
expected = bm.expand_dims(expected, axis=-1)
164164
self.assertTrue(bm.allclose(outs, expected))
165165

166+
def test_disable_jit(self):
167+
def cumsum(res, el):
168+
res = res + el
169+
print(res)
170+
return res, res # ("carryover", "accumulated")
171+
172+
a = bm.array([1, 2, 3, 5, 7, 11, 13, 17]).value
173+
result_init = 0
174+
with jax.disable_jit():
175+
final, result = jax.lax.scan(cumsum, result_init, a)
176+
177+
b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
178+
result_init = 0
179+
with jax.disable_jit():
180+
final, result = bm.scan(cumsum, result_init, b)
181+
182+
bm.clear_buffer_memory()
183+
184+
def test_array_aware_of_bp_array(self):
185+
def cumsum(res, el):
186+
res = bm.asarray(res + el)
187+
return res, res # ("carryover", "accumulated")
188+
189+
b = bm.array([1, 2, 3, 5, 7, 11, 13, 17])
190+
result_init = 0
191+
with jax.disable_jit():
192+
final, result = bm.scan(cumsum, result_init, b)
193+
166194

167195
class TestCond(unittest.TestCase):
168196
def test1(self):

0 commit comments

Comments
 (0)