Skip to content

Commit 5874e8d

Browse files
authored
New OO transforms support jax.disable_jit mode (#359)
New OO transforms support ``jax.disable_jit`` mode
2 parents df1897b + 77cbd76 commit 5874e8d

File tree

10 files changed

+188
-106
lines changed

10 files changed

+188
-106
lines changed

brainpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22

3-
__version__ = "2.3.8"
3+
__version__ = "2.4.0"
44

55

66
# fundamental supporting modules

brainpy/_src/checkpoints/tests/test_io.py

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs):
1212

1313
rng = bm.random.RandomState()
1414

15-
class IO1(bp.dyn.DynamicalSystem):
15+
class IO1(bp.DynamicalSystem):
1616
def __init__(self):
1717
super(IO1, self).__init__()
1818

@@ -21,7 +21,7 @@ def __init__(self):
2121
self.c = bm.Variable(bm.ones((3, 4)))
2222
self.d = bm.Variable(bm.ones((2, 3, 4)))
2323

24-
class IO2(bp.dyn.DynamicalSystem):
24+
class IO2(bp.DynamicalSystem):
2525
def __init__(self):
2626
super(IO2, self).__init__()
2727

@@ -35,59 +35,59 @@ def __init__(self):
3535
io2.a2 = io1.a
3636
io2.b2 = io2.b
3737

38-
self.net = bp.dyn.Container(io1, io2)
38+
self.net = bp.Container(io1, io2)
3939

4040
print(self.net.vars().keys())
4141
print(self.net.vars().unique().keys())
4242

4343
def test_h5(self):
44-
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
45-
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
44+
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
45+
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
4646

47-
bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
48-
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
47+
bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
48+
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
4949

5050
def test_h5_postfix(self):
5151
with self.assertRaises(ValueError):
52-
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
52+
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
5353
with self.assertRaises(ValueError):
54-
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
54+
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
5555

5656
def test_npz(self):
57-
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
58-
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
57+
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
58+
bp.checkpoints.io.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
5959

60-
bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
61-
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
60+
bp.checkpoints.io.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
61+
bp.checkpoints.io.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
6262

6363
def test_npz_postfix(self):
6464
with self.assertRaises(ValueError):
65-
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
65+
bp.checkpoints.io.save_as_npz('io_test_tmp.npz2', self.net.vars())
6666
with self.assertRaises(ValueError):
67-
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
67+
bp.checkpoints.io.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
6868

6969
def test_pkl(self):
70-
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
71-
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
70+
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl', self.net.vars())
71+
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
7272

73-
bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
74-
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
73+
bp.checkpoints.io.save_as_pkl('io_test_tmp.pickle', self.net.vars())
74+
bp.checkpoints.io.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
7575

7676
def test_pkl_postfix(self):
7777
with self.assertRaises(ValueError):
78-
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
78+
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
7979
with self.assertRaises(ValueError):
80-
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
80+
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
8181

8282
def test_mat(self):
83-
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
84-
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
83+
bp.checkpoints.io.save_as_mat('io_test_tmp.mat', self.net.vars())
84+
bp.checkpoints.io.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
8585

8686
def test_mat_postfix(self):
8787
with self.assertRaises(ValueError):
88-
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
88+
bp.checkpoints.io.save_as_mat('io_test_tmp.mat2', self.net.vars())
8989
with self.assertRaises(ValueError):
90-
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
90+
bp.checkpoints.io.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
9191

9292

9393
class TestIO2(unittest.TestCase):
@@ -96,7 +96,7 @@ def __init__(self, *args, **kwargs):
9696

9797
rng = bm.random.RandomState()
9898

99-
class IO1(bp.dyn.DynamicalSystem):
99+
class IO1(bp.DynamicalSystem):
100100
def __init__(self):
101101
super(IO1, self).__init__()
102102

@@ -105,7 +105,7 @@ def __init__(self):
105105
self.c = bm.Variable(bm.ones((3, 4)))
106106
self.d = bm.Variable(bm.ones((2, 3, 4)))
107107

108-
class IO2(bp.dyn.DynamicalSystem):
108+
class IO2(bp.DynamicalSystem):
109109
def __init__(self):
110110
super(IO2, self).__init__()
111111

@@ -115,56 +115,56 @@ def __init__(self):
115115
io1 = IO1()
116116
io2 = IO2()
117117

118-
self.net = bp.dyn.Container(io1, io2)
118+
self.net = bp.Container(io1, io2)
119119

120120
print(self.net.vars().keys())
121121
print(self.net.vars().unique().keys())
122122

123123
def test_h5(self):
124-
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
125-
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
124+
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
125+
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
126126

127-
bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
128-
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
127+
bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
128+
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
129129

130130
def test_h5_postfix(self):
131131
with self.assertRaises(ValueError):
132-
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
132+
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
133133
with self.assertRaises(ValueError):
134-
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
134+
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
135135

136136
def test_npz(self):
137-
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
138-
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
137+
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
138+
bp.checkpoints.io.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
139139

140-
bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
141-
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
140+
bp.checkpoints.io.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
141+
bp.checkpoints.io.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
142142

143143
def test_npz_postfix(self):
144144
with self.assertRaises(ValueError):
145-
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
145+
bp.checkpoints.io.save_as_npz('io_test_tmp.npz2', self.net.vars())
146146
with self.assertRaises(ValueError):
147-
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
147+
bp.checkpoints.io.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
148148

149149
def test_pkl(self):
150-
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
151-
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
150+
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl', self.net.vars())
151+
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
152152

153-
bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
154-
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
153+
bp.checkpoints.io.save_as_pkl('io_test_tmp.pickle', self.net.vars())
154+
bp.checkpoints.io.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
155155

156156
def test_pkl_postfix(self):
157157
with self.assertRaises(ValueError):
158-
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
158+
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
159159
with self.assertRaises(ValueError):
160-
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
160+
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
161161

162162
def test_mat(self):
163-
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
164-
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
163+
bp.checkpoints.io.save_as_mat('io_test_tmp.mat', self.net.vars())
164+
bp.checkpoints.io.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
165165

166166
def test_mat_postfix(self):
167167
with self.assertRaises(ValueError):
168-
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
168+
bp.checkpoints.io.save_as_mat('io_test_tmp.mat2', self.net.vars())
169169
with self.assertRaises(ValueError):
170-
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
170+
bp.checkpoints.io.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)

brainpy/_src/math/object_transform/autograd.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
from brainpy import tools, check
1919
from brainpy._src.math.ndarray import Array
20-
from brainpy._src.math.object_transform.variables import Variable
21-
from brainpy._src.math.object_transform.base import BrainPyObject, ObjectTransform
22-
from brainpy._src.math.object_transform._tools import (dynvar_deprecation,
23-
node_deprecation,
24-
evaluate_dyn_vars)
20+
from .variables import Variable
21+
from .base import BrainPyObject, ObjectTransform
22+
from ._tools import (dynvar_deprecation,
23+
node_deprecation,
24+
evaluate_dyn_vars)
2525

2626
__all__ = [
2727
'grad', # gradient of scalar function

brainpy/_src/math/object_transform/controls.py

Lines changed: 35 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
from brainpy import errors, tools, check
1313
from brainpy._src.math.interoperability import as_jax
1414
from brainpy._src.math.ndarray import (Array, )
15-
from brainpy._src.math.object_transform._tools import (evaluate_dyn_vars,
16-
dynvar_deprecation,
17-
node_deprecation,
18-
abstract)
19-
from brainpy._src.math.object_transform.variables import (Variable, VariableStack)
20-
from brainpy._src.math.object_transform.naming import (get_unique_name,
21-
get_stack_cache,
22-
cache_stack)
15+
from ._tools import (evaluate_dyn_vars,
16+
dynvar_deprecation,
17+
node_deprecation,
18+
abstract)
19+
from .variables import (Variable, VariableStack)
20+
from .naming import (get_unique_name,
21+
get_stack_cache,
22+
cache_stack)
2323
from ._utils import infer_dyn_vars
2424
from .base import BrainPyObject, ArrayCollector, ObjectTransform
2525

@@ -483,19 +483,15 @@ def cond(
483483
operands = (operands,)
484484

485485
# dyn vars
486-
if dyn_vars is None:
487-
dyn_vars = evaluate_dyn_vars(true_fun, *operands)
488-
dyn_vars += evaluate_dyn_vars(false_fun, *operands)
486+
dynvar_deprecation(dyn_vars)
487+
node_deprecation(child_objs)
488+
489+
if jax.config.jax_disable_jit:
490+
dyn_vars = VariableStack()
489491

490492
else:
491-
dynvar_deprecation(dyn_vars)
492-
node_deprecation(child_objs)
493-
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
494-
dyn_vars = ArrayCollector(dyn_vars)
495-
dyn_vars.update(infer_dyn_vars(true_fun))
496-
dyn_vars.update(infer_dyn_vars(false_fun))
497-
for obj in check.is_all_objs(child_objs, out_as='tuple'):
498-
dyn_vars.update(obj.vars().unique())
493+
dyn_vars = evaluate_dyn_vars(true_fun, *operands)
494+
dyn_vars += evaluate_dyn_vars(false_fun, *operands)
499495

500496
# TODO: cache mechanism?
501497
if len(dyn_vars) > 0:
@@ -746,14 +742,19 @@ def for_loop(
746742
if not isinstance(operands, (list, tuple)):
747743
operands = (operands,)
748744

749-
# TODO: better cache mechanism?
750745
dyn_vars = get_stack_cache(body_fun)
751-
if dyn_vars is None:
752-
with jax.ensure_compile_time_eval():
753-
op_vals = jax.tree_util.tree_map(_loop_abstractify, operands)
754-
with VariableStack() as dyn_vars:
755-
_ = jax.eval_shape(body_fun, *op_vals)
756-
cache_stack(body_fun, dyn_vars) # cache
746+
if not jit:
747+
if dyn_vars is None:
748+
dyn_vars = VariableStack()
749+
750+
else:
751+
# TODO: better cache mechanism?
752+
if dyn_vars is None:
753+
with jax.ensure_compile_time_eval():
754+
op_vals = jax.tree_util.tree_map(_loop_abstractify, operands)
755+
with VariableStack() as dyn_vars:
756+
_ = jax.eval_shape(body_fun, *op_vals)
757+
cache_stack(body_fun, dyn_vars) # cache
757758

758759
# functions
759760
def fun2scan(carry, x):
@@ -762,7 +763,8 @@ def fun2scan(carry, x):
762763
results = body_fun(*x)
763764
return dyn_vars.dict_data(), results
764765

765-
if remat: fun2scan = jax.checkpoint(fun2scan)
766+
if remat:
767+
fun2scan = jax.checkpoint(fun2scan)
766768

767769
# TODO: cache mechanism?
768770
with jax.disable_jit(not jit):
@@ -851,8 +853,12 @@ def while_loop(
851853
if not isinstance(operands, (list, tuple)):
852854
operands = (operands,)
853855

854-
dyn_vars = evaluate_dyn_vars(body_fun, *operands)
855-
dyn_vars += evaluate_dyn_vars(cond_fun, *operands)
856+
if jax.config.jax_disable_jit:
857+
dyn_vars = VariableStack()
858+
859+
else:
860+
dyn_vars = evaluate_dyn_vars(body_fun, *operands)
861+
dyn_vars += evaluate_dyn_vars(cond_fun, *operands)
856862

857863
def _body_fun(op):
858864
dyn_vals, old_vals = op

0 commit comments

Comments
 (0)