77
88"""
99
10- from typing import Callable , Union , Optional , Sequence , Dict , Any
10+ from typing import Callable , Union , Optional , Sequence , Dict , Any , Iterable
1111
1212import jax
13-
14- try :
15- from jax .errors import UnexpectedTracerError , ConcretizationTypeError
16- except ImportError :
17- from jax .core import UnexpectedTracerError , ConcretizationTypeError
13+ from jax .errors import UnexpectedTracerError , ConcretizationTypeError
1814
1915from brainpy import errors , tools , check
2016from brainpy ._src .math .ndarray import Variable , add_context , del_context
2117from .abstract import ObjectTransform
2218from .base import BrainPyObject
23- from ._utils import infer_dyn_vars
2419
2520__all__ = [
2621 'jit' ,
@@ -35,7 +30,8 @@ def __init__(
3530 target : callable ,
3631 dyn_vars : Dict [str , Variable ],
3732 child_objs : Dict [str , BrainPyObject ],
38- static_argnames : Optional [Any ] = None ,
33+ static_argnums : Union [int , Iterable [int ], None ] = None ,
34+ static_argnames : Union [str , Iterable [str ], None ] = None ,
3935 device : Optional [Any ] = None ,
4036 name : Optional [str ] = None ,
4137 inline : bool = False ,
@@ -46,12 +42,14 @@ def __init__(
4642
4743 self .register_implicit_vars (dyn_vars )
4844 self .register_implicit_nodes (child_objs )
49-
45+ if hasattr (target , '__self__' ) and isinstance (getattr (target , '__self__' ), BrainPyObject ):
46+ self .register_implicit_nodes (getattr (target , '__self__' ))
5047 self .target = target
5148 self ._all_vars = self .vars ().unique ()
5249
5350 # transformation
5451 self ._f = jax .jit (self ._transform_function ,
52+ static_argnums = jax .tree_util .tree_map (lambda a : a + 1 , static_argnums ),
5553 static_argnames = static_argnames ,
5654 device = device ,
5755 inline = inline ,
@@ -100,7 +98,8 @@ def jit(
10098 func : Callable ,
10199 dyn_vars : Optional [Union [Variable , Sequence [Variable ], Dict [str , Variable ]]] = None ,
102100 child_objs : Optional [Union [BrainPyObject , Sequence [BrainPyObject ], Dict [str , BrainPyObject ]]] = None ,
103- static_argnames : Optional [Union [str , Any ]] = None ,
101+ static_argnums : Union [int , Iterable [int ], None ] = None ,
102+ static_argnames : Union [str , Iterable [str ], None ] = None ,
104103 device : Optional [Any ] = None ,
105104 inline : bool = False ,
106105 keep_unused : bool = False ,
@@ -230,6 +229,7 @@ def jit(
230229 return JITTransform (target = func ,
231230 dyn_vars = dyn_vars ,
232231 child_objs = child_objs ,
232+ static_argnums = static_argnums ,
233233 static_argnames = static_argnames ,
234234 device = device ,
235235 inline = inline ,
0 commit comments