
Commit ba72bba

[oo transform] fix bugs on OO-transform changes

1 parent: d5ea988

8 files changed: +53 −43 lines

brainpy/_src/analysis/lowdim/tests/test_phase_plane.py

Lines changed: 5 additions & 3 deletions
@@ -7,7 +7,7 @@
 import jax.numpy as jnp


-block = False
+show = False


 class TestPhasePlane(unittest.TestCase):
@@ -27,7 +27,8 @@ def int_x(x, t, Iext):
     plt.ion()
     analyzer.plot_vector_field()
     analyzer.plot_fixed_point()
-    plt.show(block=block)
+    if show:
+      plt.show()
     plt.close()
     bp.math.disable_x64()
@@ -74,6 +75,7 @@ def int_s2(s2, t, s1):
     analyzer.plot_vector_field()
     analyzer.plot_nullcline(coords=dict(s2='s2-s1'))
     analyzer.plot_fixed_point()
-    plt.show(block=block)
+    if show:
+      plt.show()
     plt.close()
     bp.math.disable_x64()
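The tests swap plt.show(block=block) for an opt-in show flag, so a test run never pops up or blocks on a figure window unless a developer flips the flag. A minimal sketch of the same guard pattern in isolation; the Agg backend line is an addition for headless CI, not part of this commit:

import matplotlib
matplotlib.use('Agg')  # non-interactive backend; assumption, not in the commit
import matplotlib.pyplot as plt

show = False  # flip to True for local, interactive debugging

def check_plot():
  plt.plot([0, 1], [1, 0])
  if show:
    plt.show()   # only blocks when explicitly requested
  plt.close()    # always release the figure so repeated tests don't leak memory

check_plot()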

brainpy/_src/dyn/runners.py

Lines changed: 3 additions & 8 deletions
@@ -668,10 +668,6 @@ def _get_f_predict(self, shared_args: Dict = None):

     shared_kwargs_str = serialize_kwargs(shared_args)
     if shared_kwargs_str not in self._f_predict_compiled:
-      dyn_vars = self.target.vars()
-      dyn_vars.update(self._dyn_vars)
-      dyn_vars.update(self.vars(level=0))
-      dyn_vars = dyn_vars.unique()

       if self._memory_efficient:
         _jit_step = bm.jit(partial(self._step_func_predict, shared_args))
@@ -688,11 +684,10 @@ def run_func(all_inputs):
         return outs, None

       else:
-        @bm.jit
+        step = partial(self._step_func_predict, shared_args)
+
         def run_func(all_inputs):
-          return bm.for_loop(partial(self._step_func_predict, shared_args),
-                             all_inputs,
-                             jit=self.jit['predict'])
+          return bm.for_loop(step, all_inputs, jit=self.jit['predict'])

       self._f_predict_compiled[shared_kwargs_str] = run_func
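Two things change in _get_f_predict: the hand-rolled dyn_vars collection disappears, because the new OO transforms record the Variables a function uses on their own, and run_func drops its @bm.jit wrapper, since bm.for_loop already accepts a jit flag and jitting the loop twice is redundant. A rough sketch of the resulting call shape; step_fn and its body are illustrative, not the runner's real step function:

import brainpy.math as bm

state = bm.Variable(bm.zeros(1))

def step_fn(x):
  state.value += x          # reads/writes a Variable; for_loop tracks it itself
  return state.value

xs = bm.arange(5.)                         # one element per time step
outs = bm.for_loop(step_fn, xs, jit=True)  # jit handled inside for_loop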

brainpy/_src/integrators/sde/normal.py

Lines changed: 1 addition & 1 deletion
@@ -292,7 +292,7 @@ def _get_g_grad(self, f, allow_raise=False, need_grad=True):
       if not allow_raise:
         raise e
     if need_grad:
-      res[0] = bm.vector_grad(f, argnums=0, dyn_vars=self.dyn_vars)
+      res[0] = bm.vector_grad(f, argnums=0)
     return [tuple(res)], state

   def step(self, *args, **kwargs):
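Same theme: under the new transform machinery, bm.vector_grad discovers used Variables automatically, so the explicit dyn_vars argument goes away. A minimal sketch; f and a are illustrative:

import brainpy.math as bm

a = bm.Variable(bm.ones(3))

def f(x):
  return a * x ** 2                # reads the Variable `a` implicitly

df = bm.vector_grad(f, argnums=0)  # no dyn_vars argument needed any more
print(df(bm.ones(3)))              # elementwise gradient: 2 * a * x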

brainpy/_src/integrators/sde/tests/test_normal.py

Lines changed: 26 additions & 19 deletions
@@ -7,7 +7,7 @@
 import matplotlib.pyplot as plt
 from brainpy._src.integrators.sde.normal import ExponentialEuler

-block = False
+show = False


 class TestExpEuler(unittest.TestCase):
@@ -33,16 +33,18 @@ def lorenz_g(x, y, z, t, **kwargs):
     runner.run(100.)

     plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten())
-    plt.show(block=block)
+    if show:
+      plt.show()
+    plt.close()

   def test2(self):
     p = 0.1
     p2 = 0.02

     def lorenz_g(x, y, z, t, **kwargs):
       return bp.math.asarray([p * x, p2 * x]), \
-             bp.math.asarray([p * y, p2 * y]), \
-             bp.math.asarray([p * z, p2 * z])
+        bp.math.asarray([p * y, p2 * y]), \
+        bp.math.asarray([p * z, p2 * z])

     dx = lambda x, t, y, sigma=10: sigma * (y - x)
     dy = lambda y, t, x, z, rho=28: x * (rho - z) - y
@@ -54,8 +56,8 @@ def lorenz_g(x, y, z, t, **kwargs):
                              wiener_type=bp.integrators.VECTOR_WIENER,
                              var_type=bp.integrators.POP_VAR,
                              show_code=True)
-    runner = bp.integrators.IntegratorRunner(intg, monitors=['x', 'y', 'z'],
-                                             dt=0.001, inits=[1., 1., 0.], jit=False)
+    runner = bp.IntegratorRunner(intg, monitors=['x', 'y', 'z'],
+                                 dt=0.001, inits=[1., 1., 0.], jit=False)
     with self.assertRaises(ValueError):
       runner.run(100.)

@@ -65,8 +67,8 @@ def test3(self):

     def lorenz_g(x, y, z, t, **kwargs):
       return bp.math.asarray([p * x, p2 * x]).T, \
-             bp.math.asarray([p * y, p2 * y]).T, \
-             bp.math.asarray([p * z, p2 * z]).T
+        bp.math.asarray([p * y, p2 * y]).T, \
+        bp.math.asarray([p * z, p2 * z]).T

     dx = lambda x, t, y, sigma=10: sigma * (y - x)
     dy = lambda y, t, x, z, rho=28: x * (rho - z) - y
@@ -78,15 +80,17 @@ def lorenz_g(x, y, z, t, **kwargs):
                              wiener_type=bp.integrators.VECTOR_WIENER,
                              var_type=bp.integrators.POP_VAR,
                              show_code=True)
-    runner = bp.integrators.IntegratorRunner(intg,
-                                             monitors=['x', 'y', 'z'],
-                                             dt=0.001,
-                                             inits=[1., 1., 0.],
-                                             jit=True)
+    runner = bp.IntegratorRunner(intg,
+                                 monitors=['x', 'y', 'z'],
+                                 dt=0.001,
+                                 inits=[1., 1., 0.],
+                                 jit=True)
     runner.run(100.)

     plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten())
-    plt.show(block=block)
+    if show:
+      plt.show()
+    plt.close()


 class TestMilstein(unittest.TestCase):
@@ -110,11 +114,14 @@ def test1(self):
                              wiener_type=bp.integrators.SCALAR_WIENER,
                              var_type=bp.integrators.POP_VAR,
                              method='milstein')
-    runner = bp.integrators.IntegratorRunner(intg,
-                                             monitors=['x', 'y', 'z'],
-                                             dt=0.001, inits=[1., 1., 0.],
-                                             jit=True)
+    runner = bp.IntegratorRunner(intg,
+                                 monitors=['x', 'y', 'z'],
+                                 dt=0.001, inits=[1., 1., 0.],
+                                 jit=True)
     runner.run(100.)

     plt.plot(runner.mon.x.flatten(), runner.mon.y.flatten())
-    plt.show(block=block)
+    if show:
+      plt.show()
+    plt.close()
+
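Besides the show guard, these tests now construct the runner through the top-level bp.IntegratorRunner alias instead of bp.integrators.IntegratorRunner. A minimal usage sketch with an illustrative one-variable ODE (not the Lorenz system from the tests):

import brainpy as bp

dx = lambda x, t: -x                    # simple exponential decay
intg = bp.odeint(dx)
runner = bp.IntegratorRunner(intg, monitors=['x'], inits=[1.], dt=0.01)
runner.run(1.)
print(runner.mon.x.shape)               # time series of x, one row per step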
brainpy/_src/math/object_transform/_tools.py

Lines changed: 12 additions & 8 deletions
@@ -16,7 +16,9 @@ class Empty(object):
 empty = Empty()


-def _partial_fun(fun, args, kwargs,
+def _partial_fun(fun,
+                 args: tuple,
+                 kwargs: dict,
                  static_argnums: Sequence[int] = (),
                  static_argnames: Sequence[str] = ()):
   static_args, dyn_args = [], []
@@ -35,16 +37,16 @@ def _partial_fun(fun, args, kwargs,
   del args, kwargs, static_argnums, static_argnames

   @wraps(fun)
-  def new_fun(*dyn_args, **dyn_kwargs):
+  def new_fun(*dynargs, **dynkwargs):
     args = []
     i = 0
     for arg in static_args:
       if arg == empty:
-        args.append(dyn_args[i])
+        args.append(dynargs[i])
         i += 1
       else:
         args.append(arg)
-    return fun(*args, **static_kwargs, **dyn_kwargs)
+    return fun(*args, **static_kwargs, **dynkwargs)

   return new_fun, dyn_args, dyn_kwargs

@@ -80,14 +82,16 @@ def evaluate_dyn_vars(f,
                       static_argnames: Sequence[str] = (),
                       **kwargs):
   # TODO: better way for cache mechanism
-  if len(static_argnums) or len(static_argnames):
-    f, args, kwargs = _partial_fun(f, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
   stack = get_stack_cache(f)
   if stack is None:
+    if len(static_argnums) or len(static_argnames):
+      f2, args, kwargs = _partial_fun(f, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
+    else:
+      f2, args, kwargs = f, args, kwargs
     with jax.ensure_compile_time_eval():
       with VariableStack() as stack:
-        _ = jax.eval_shape(f, *args, *kwargs)
+        _ = jax.eval_shape(f2, *args, **kwargs)
     cache_stack(f, stack)  # cache
-  del args, kwargs
+  del args, kwargs, f2
   return stack
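Three fixes land in this file. The rebinding to the partial function now happens under the cache miss and into a new name f2, so the stack cache stays keyed on the original f. The inner parameters of new_fun are renamed to dynargs/dynkwargs so they no longer shadow the dyn_args that _partial_fun returns. And jax.eval_shape was previously called with *kwargs, which unpacks a dict's keys as extra positional arguments rather than passing keywords. A self-contained illustration of that last pitfall:

def g(x, y=0):
  return (x, y)

kwargs = {'y': 2}
print(g(1, **kwargs))  # (1, 2)  -- keyword unpacking, what eval_shape needs
print(g(1, *kwargs))   # (1, 'y') -- *dict unpacks the KEYS positionally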

brainpy/_src/math/object_transform/jit.py

Lines changed: 2 additions & 1 deletion
@@ -106,7 +106,8 @@ def __call__(self, *args, **kwargs):
       return self.fun(*args, **kwargs)

     if self._transform is None:
-      self._dyn_vars = evaluate_dyn_vars(self.fun, *args,
+      self._dyn_vars = evaluate_dyn_vars(self.fun,
+                                         *args,
                                          static_argnums=self._static_argnums,
                                          static_argnames=self._static_argnames,
                                          **kwargs)
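This hunk only reflows the call, but it is where the pieces above meet: the first invocation of a jitted function traces it once through evaluate_dyn_vars to record which Variables it touches, and later calls reuse the cached stack. A rough sketch of that behavior from the user's side, assuming bm.jit is applied to a plain function:

import brainpy.math as bm

v = bm.Variable(bm.zeros(1))

@bm.jit
def update(x):
  v.value += x        # first call records `v` in the variable stack
  return v.value

print(update(1.))     # triggers evaluate_dyn_vars, then compiles
print(update(2.))     # reuses the cached transform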

brainpy/_src/math/object_transform/tests/test_autograd.py

Lines changed: 3 additions & 2 deletions
@@ -100,9 +100,10 @@ def __call__(self):
     bm.random.seed(0)

     t = Test()
-    f_grad = bm.grad(t)
+    f_grad = bm.grad(t, grad_vars={'a': t.a, 'b': t.b, 'c': t.c})
     grads = f_grad()
-    for g in grads.values(): assert (g == 1.).all()
+    for g in grads.values():
+      assert (g == 1.).all()

     t = Test()
     f_grad = bm.grad(t, grad_vars=[t.a, t.b])
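The test now names its differentiation targets explicitly through grad_vars instead of relying on bm.grad to infer them from the object. A minimal sketch of the dict form; this Test class is a stand-in reconstructed from the test, not the commit's code:

import brainpy as bp
import brainpy.math as bm

class Test(bp.BrainPyObject):
  def __init__(self):
    super().__init__()
    self.a = bm.TrainVar(bm.ones(3))
    self.b = bm.TrainVar(bm.ones(3))

  def __call__(self):
    return (self.a + self.b).sum()

t = Test()
f_grad = bm.grad(t, grad_vars={'a': t.a, 'b': t.b})
grads = f_grad()               # a dict with the same keys as grad_vars
for g in grads.values():
  assert (g == 1.).all()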

brainpy/_src/math/object_transform/tests/test_controls.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 import brainpy.math as bm


-class TestLoop(jtu.JaxTestCase):
+class TestLoop(parameterized.TestCase):
   def test_make_loop(self):
     def make_node(v1, v2):
       def update(x):
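A likely motivation for this last change, not stated in the commit: jtu.JaxTestCase comes from jax's test utilities, which JAX stopped exposing as a stable public API, so basing the suite on absl's parameterized.TestCase removes that dependency.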
