Skip to content

Commit 7693f14

Browse files
authored
JaxArray transformation context (#277)
JaxArray transformation context
2 parents 2bf7ba7 + a7e5053 commit 7693f14

File tree

5 files changed

+214
-90
lines changed

5 files changed

+214
-90
lines changed

brainpy/math/autograd.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from jax.util import safe_map
1717

1818
from brainpy import errors
19-
from brainpy.math.jaxarray import JaxArray
19+
from brainpy.base.naming import get_unique_name
20+
from brainpy.math.jaxarray import JaxArray, add_context, del_context
21+
2022

2123
__all__ = [
2224
'grad', # gradient of scalar function
@@ -28,20 +30,26 @@
2830

2931
def _make_cls_call_func(grad_func, grad_tree, grad_vars, dyn_vars,
3032
argnums, return_value, has_aux):
33+
name = get_unique_name('_brainpy_object_oriented_grad_')
34+
3135
# outputs
3236
def call_func(*args, **kwargs):
3337
old_grad_vs = [v.value for v in grad_vars]
3438
old_dyn_vs = [v.value for v in dyn_vars]
3539
try:
40+
add_context(name)
3641
grads, (outputs, new_grad_vs, new_dyn_vs) = grad_func(old_grad_vs,
3742
old_dyn_vs,
3843
*args,
3944
**kwargs)
45+
del_context(name)
4046
except UnexpectedTracerError as e:
47+
del_context(name)
4148
for v, d in zip(grad_vars, old_grad_vs): v._value = d
4249
for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
4350
raise errors.JaxTracerError(variables=dyn_vars + grad_vars) from e
4451
except Exception as e:
52+
del_context(name)
4553
for v, d in zip(grad_vars, old_grad_vs): v._value = d
4654
for v, d in zip(dyn_vars, old_dyn_vs): v._value = d
4755
raise e

brainpy/math/controls.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from jax.core import UnexpectedTracerError
1414

1515
from brainpy import errors
16+
from brainpy.base.naming import get_unique_name
1617
from brainpy.math.jaxarray import (JaxArray, Variable,
17-
turn_on_global_jit,
18-
turn_off_global_jit)
18+
add_context,
19+
del_context)
1920
from brainpy.math.numpy_ops import as_device_array
2021

2122
__all__ = [
@@ -158,17 +159,19 @@ def make_loop(body_fun, dyn_vars, out_vars=None, has_return=False):
158159
out_vars=out_vars,
159160
has_return=has_return)
160161

162+
name = get_unique_name('_brainpy_object_oriented_make_loop_')
163+
161164
# functions
162165
if has_return:
163166
def call(xs=None, length=None):
164167
init_values = [v.value for v in dyn_vars]
165168
try:
166-
turn_on_global_jit()
169+
add_context(name)
167170
dyn_values, (out_values, results) = lax.scan(
168171
f=fun2scan, init=init_values, xs=xs, length=length)
169-
turn_off_global_jit()
172+
del_context(name)
170173
except UnexpectedTracerError as e:
171-
turn_off_global_jit()
174+
del_context(name)
172175
for v, d in zip(dyn_vars, init_values): v._value = d
173176
raise errors.JaxTracerError(variables=dyn_vars) from e
174177
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -178,15 +181,15 @@ def call(xs=None, length=None):
178181
def call(xs):
179182
init_values = [v.value for v in dyn_vars]
180183
try:
181-
turn_on_global_jit()
184+
add_context(name)
182185
dyn_values, out_values = lax.scan(f=fun2scan, init=init_values, xs=xs)
183-
turn_off_global_jit()
186+
del_context(name)
184187
except UnexpectedTracerError as e:
185-
turn_off_global_jit()
188+
del_context(name)
186189
for v, d in zip(dyn_vars, init_values): v._value = d
187190
raise errors.JaxTracerError(variables=dyn_vars) from e
188191
except Exception as e:
189-
turn_off_global_jit()
192+
del_context(name)
190193
for v, d in zip(dyn_vars, init_values): v._value = d
191194
raise e
192195
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -255,20 +258,22 @@ def _cond_fun(op):
255258
for v, d in zip(dyn_vars, dyn_values): v._value = d
256259
return as_device_array(cond_fun(static_values))
257260

261+
name = get_unique_name('_brainpy_object_oriented_make_while_')
262+
258263
def call(x=None):
259264
dyn_init = [v.value for v in dyn_vars]
260265
try:
261-
turn_on_global_jit()
266+
add_context(name)
262267
dyn_values, _ = lax.while_loop(cond_fun=_cond_fun,
263268
body_fun=_body_fun,
264269
init_val=(dyn_init, x))
265-
turn_off_global_jit()
270+
del_context(name)
266271
except UnexpectedTracerError as e:
267-
turn_off_global_jit()
272+
del_context(name)
268273
for v, d in zip(dyn_vars, dyn_init): v._value = d
269274
raise errors.JaxTracerError(variables=dyn_vars) from e
270275
except Exception as e:
271-
turn_off_global_jit()
276+
del_context(name)
272277
for v, d in zip(dyn_vars, dyn_init): v._value = d
273278
raise e
274279
for v, d in zip(dyn_vars, dyn_values): v._value = d
@@ -330,6 +335,8 @@ def make_cond(true_fun, false_fun, dyn_vars=None):
330335
if not isinstance(v, JaxArray):
331336
raise ValueError(f'Only support {JaxArray.__name__}, but got {type(v)}')
332337

338+
name = get_unique_name('_brainpy_object_oriented_make_cond_')
339+
333340
if len(dyn_vars) > 0:
334341
def _true_fun(op):
335342
dyn_vals, static_vals = op
@@ -348,25 +355,25 @@ def _false_fun(op):
348355
def call(pred, x=None):
349356
old_values = [v.value for v in dyn_vars]
350357
try:
351-
turn_on_global_jit()
358+
add_context(name)
352359
dyn_values, res = lax.cond(pred, _true_fun, _false_fun, (old_values, x))
353-
turn_off_global_jit()
360+
del_context(name)
354361
except UnexpectedTracerError as e:
355-
turn_off_global_jit()
362+
del_context(name)
356363
for v, d in zip(dyn_vars, old_values): v._value = d
357364
raise errors.JaxTracerError(variables=dyn_vars) from e
358365
except Exception as e:
359-
turn_off_global_jit()
366+
del_context(name)
360367
for v, d in zip(dyn_vars, old_values): v._value = d
361368
raise e
362369
for v, d in zip(dyn_vars, dyn_values): v._value = d
363370
return res
364371

365372
else:
366373
def call(pred, x=None):
367-
turn_on_global_jit()
374+
add_context(name)
368375
res = lax.cond(pred, true_fun, false_fun, x)
369-
turn_off_global_jit()
376+
del_context(name)
370377
return res
371378

372379
return call
@@ -445,6 +452,8 @@ def cond(
445452
if not isinstance(v, Variable):
446453
raise ValueError(f'Only support {Variable.__name__}, but got {type(v)}')
447454

455+
name = get_unique_name('_brainpy_object_oriented_cond_')
456+
448457
# calling the model
449458
if len(dyn_vars) > 0:
450459
def _true_fun(op):
@@ -463,25 +472,25 @@ def _false_fun(op):
463472

464473
old_values = [v.value for v in dyn_vars]
465474
try:
466-
turn_on_global_jit()
475+
add_context(name)
467476
dyn_values, res = lax.cond(pred=pred,
468477
true_fun=_true_fun,
469478
false_fun=_false_fun,
470479
operand=(old_values, operands))
471-
turn_off_global_jit()
480+
del_context(name)
472481
except UnexpectedTracerError as e:
473-
turn_off_global_jit()
482+
del_context(name)
474483
for v, d in zip(dyn_vars, old_values): v._value = d
475484
raise errors.JaxTracerError(variables=dyn_vars) from e
476485
except Exception as e:
477-
turn_off_global_jit()
486+
del_context(name)
478487
for v, d in zip(dyn_vars, old_values): v._value = d
479488
raise e
480489
for v, d in zip(dyn_vars, dyn_values): v._value = d
481490
else:
482-
turn_on_global_jit()
491+
add_context(name)
483492
res = lax.cond(pred, true_fun, false_fun, operands)
484-
turn_off_global_jit()
493+
del_context(name)
485494
return res
486495

487496

@@ -591,7 +600,11 @@ def ifelse(
591600
if show_code: print(codes)
592601
exec(compile(codes.strip(), '', 'exec'), code_scope)
593602
f = code_scope['f']
594-
return f(operands)
603+
name = get_unique_name('_brainpy_object_oriented_ifelse_')
604+
add_context(name)
605+
r = f(operands)
606+
del_context(name)
607+
return r
595608

596609

597610
def for_loop(body_fun: Callable,
@@ -694,22 +707,24 @@ def fun2scan(dyn_vals, x):
694707
results = body_fun(*x)
695708
return [v.value for v in dyn_vars], results
696709

710+
name = get_unique_name('_brainpy_object_oriented_for_loop_')
711+
697712
# functions
698713
init_vals = [v.value for v in dyn_vars]
699714
try:
700-
turn_on_global_jit()
715+
add_context(name)
701716
dyn_vals, out_vals = lax.scan(f=fun2scan,
702717
init=init_vals,
703718
xs=operands,
704719
reverse=reverse,
705720
unroll=unroll)
706-
turn_off_global_jit()
721+
del_context(name)
707722
except UnexpectedTracerError as e:
708-
turn_off_global_jit()
723+
del_context(name)
709724
for v, d in zip(dyn_vars, init_vals): v._value = d
710725
raise errors.JaxTracerError(variables=dyn_vars) from e
711726
except Exception as e:
712-
turn_off_global_jit()
727+
del_context(name)
713728
for v, d in zip(dyn_vars, init_vals): v._value = d
714729
raise e
715730
for v, d in zip(dyn_vars, dyn_vals): v._value = d
@@ -797,19 +812,20 @@ def _cond_fun(op):
797812
r = cond_fun(*static_vals)
798813
return r if isinstance(r, JaxArray) else r
799814

815+
name = get_unique_name('_brainpy_object_oriented_while_loop_')
800816
dyn_init = [v.value for v in dyn_vars]
801817
try:
802-
turn_on_global_jit()
818+
add_context(name)
803819
dyn_values, out = lax.while_loop(cond_fun=_cond_fun,
804820
body_fun=_body_fun,
805821
init_val=(dyn_init, operands))
806-
turn_off_global_jit()
822+
del_context(name)
807823
except UnexpectedTracerError as e:
808-
turn_off_global_jit()
824+
del_context(name)
809825
for v, d in zip(dyn_vars, dyn_init): v._value = d
810826
raise errors.JaxTracerError(variables=dyn_vars) from e
811827
except Exception as e:
812-
turn_off_global_jit()
828+
del_context(name)
813829
for v, d in zip(dyn_vars, dyn_init): v._value = d
814830
raise e
815831
for v, d in zip(dyn_vars, dyn_values): v._value = d

0 commit comments

Comments
 (0)