|
1 | 1 | # -*- coding: utf-8 -*- |
| 2 | + |
2 | 3 | import functools |
3 | 4 | import numbers |
4 | 5 | from typing import Union, Sequence, Any, Dict, Callable, Optional |
|
12 | 13 |
|
13 | 14 | from brainpy import errors, tools |
14 | 15 | from brainpy._src.math.interoperability import as_jax |
15 | | -from brainpy._src.math.ndarray import (Array, ) |
| 16 | +from brainpy._src.math.ndarray import (Array, _as_jax_array_) |
16 | 17 | from .base import BrainPyObject, ObjectTransform |
17 | 18 | from .naming import ( |
18 | 19 | get_unique_name, |
@@ -421,11 +422,27 @@ def call(pred, x=None): |
421 | 422 | return ControlObject(call, dyn_vars, repr_fun={'true_fun': true_fun, 'false_fun': false_fun}) |
422 | 423 |
|
423 | 424 |
|
| 425 | +@functools.cache |
| 426 | +def _warp(f): |
| 427 | + @functools.wraps(f) |
| 428 | + def new_f(*args, **kwargs): |
| 429 | + return jax.tree_map(_as_jax_array_, f(*args, **kwargs), is_leaf=lambda a: isinstance(a, Array)) |
| 430 | + |
| 431 | + return new_f |
| 432 | + |
| 433 | + |
| 434 | +def _warp_data(data): |
| 435 | + def new_f(*args, **kwargs): |
| 436 | + return jax.tree_map(_as_jax_array_, data, is_leaf=lambda a: isinstance(a, Array)) |
| 437 | + |
| 438 | + return new_f |
| 439 | + |
| 440 | + |
424 | 441 | def _check_f(f): |
425 | 442 | if callable(f): |
426 | | - return f |
| 443 | + return _warp(f) |
427 | 444 | else: |
428 | | - return (lambda *args, **kwargs: f) |
| 445 | + return _warp_data(f) |
429 | 446 |
|
430 | 447 |
|
431 | 448 | def _check_sequence(a): |
@@ -557,7 +574,7 @@ def _if_else_return2(conditions, branches): |
557 | 574 | return branches[-1] |
558 | 575 |
|
559 | 576 |
|
560 | | -def all_equal(iterator): |
| 577 | +def _all_equal(iterator): |
561 | 578 | iterator = iter(iterator) |
562 | 579 | try: |
563 | 580 | first = next(iterator) |
@@ -671,7 +688,7 @@ def ifelse( |
671 | 688 | else: |
672 | 689 | rets = [jax.eval_shape(branch, *operands) for branch in branches] |
673 | 690 | trees = [jax.tree_util.tree_structure(ret) for ret in rets] |
674 | | - if not all_equal(trees): |
| 691 | + if not _all_equal(trees): |
675 | 692 | msg = 'All returns in branches should have the same tree structure. But we got:\n' |
676 | 693 | for tree in trees: |
677 | 694 | msg += f'- {tree}\n' |
@@ -914,12 +931,14 @@ def fun2scan(carry, x): |
914 | 931 | carry, results = body_fun(carry, x) |
915 | 932 | if progress_bar: |
916 | 933 | id_tap(lambda *arg: bar.update(), ()) |
| 934 | + carry = jax.tree_map(_as_jax_array_, carry, is_leaf=lambda a: isinstance(a, Array)) |
917 | 935 | return (dyn_vars.dict_data(), carry), results |
918 | 936 |
|
919 | 937 | if remat: |
920 | 938 | fun2scan = jax.checkpoint(fun2scan) |
921 | 939 |
|
922 | 940 | def call(init, operands): |
| 941 | + init = jax.tree_map(_as_jax_array_, init, is_leaf=lambda a: isinstance(a, Array)) |
923 | 942 | return jax.lax.scan(f=fun2scan, |
924 | 943 | init=(dyn_vars.dict_data(), init), |
925 | 944 | xs=operands, |
@@ -991,19 +1010,21 @@ def scan( |
991 | 1010 | bar = tqdm(total=num_total) |
992 | 1011 |
|
993 | 1012 | dyn_vars = get_stack_cache(body_fun) |
994 | | - if dyn_vars is None: |
995 | | - with new_transform('scan'): |
996 | | - with VariableStack() as dyn_vars: |
997 | | - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) |
998 | | - if current_transform_number() > 1: |
999 | | - rets = transform(init, operands) |
1000 | | - else: |
1001 | | - rets = jax.eval_shape(transform, init, operands) |
1002 | | - cache_stack(body_fun, dyn_vars) # cache |
1003 | | - if current_transform_number(): |
1004 | | - return rets[0][1], rets[1] |
1005 | | - del rets |
| 1013 | + if not jax.config.jax_disable_jit: |
| 1014 | + if dyn_vars is None: |
| 1015 | + with new_transform('scan'): |
| 1016 | + with VariableStack() as dyn_vars: |
| 1017 | + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) |
| 1018 | + if current_transform_number() > 1: |
| 1019 | + rets = transform(init, operands) |
| 1020 | + else: |
| 1021 | + rets = jax.eval_shape(transform, init, operands) |
| 1022 | + cache_stack(body_fun, dyn_vars) # cache |
| 1023 | + if current_transform_number(): |
| 1024 | + return rets[0][1], rets[1] |
| 1025 | + del rets |
1006 | 1026 |
|
| 1027 | + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars |
1007 | 1028 | transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) |
1008 | 1029 | (dyn_vals, carry), out_vals = transform(init, operands) |
1009 | 1030 | for key in dyn_vars.keys(): |
|
0 commit comments