Skip to content

Commit adfb0b1

Browse files
committed
support static_argnums in brainpy.math.jit
1 parent 061de39 commit adfb0b1

File tree

1 file changed

+10
-10
lines changed
  • brainpy/_src/math/object_transform

1 file changed

+10
-10
lines changed

brainpy/_src/math/object_transform/jit.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,15 @@
77
88
"""
99

10-
from typing import Callable, Union, Optional, Sequence, Dict, Any
10+
from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable
1111

1212
import jax
13-
14-
try:
15-
from jax.errors import UnexpectedTracerError, ConcretizationTypeError
16-
except ImportError:
17-
from jax.core import UnexpectedTracerError, ConcretizationTypeError
13+
from jax.errors import UnexpectedTracerError, ConcretizationTypeError
1814

1915
from brainpy import errors, tools, check
2016
from brainpy._src.math.ndarray import Variable, add_context, del_context
2117
from .abstract import ObjectTransform
2218
from .base import BrainPyObject
23-
from ._utils import infer_dyn_vars
2419

2520
__all__ = [
2621
'jit',
@@ -35,7 +30,8 @@ def __init__(
3530
target: callable,
3631
dyn_vars: Dict[str, Variable],
3732
child_objs: Dict[str, BrainPyObject],
38-
static_argnames: Optional[Any] = None,
33+
static_argnums: Union[int, Iterable[int], None] = None,
34+
static_argnames: Union[str, Iterable[str], None] = None,
3935
device: Optional[Any] = None,
4036
name: Optional[str] = None,
4137
inline: bool = False,
@@ -46,12 +42,14 @@ def __init__(
4642

4743
self.register_implicit_vars(dyn_vars)
4844
self.register_implicit_nodes(child_objs)
49-
45+
if hasattr(target, '__self__') and isinstance(getattr(target, '__self__'), BrainPyObject):
46+
self.register_implicit_nodes(getattr(target, '__self__'))
5047
self.target = target
5148
self._all_vars = self.vars().unique()
5249

5350
# transformation
5451
self._f = jax.jit(self._transform_function,
52+
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums),
5553
static_argnames=static_argnames,
5654
device=device,
5755
inline=inline,
@@ -100,7 +98,8 @@ def jit(
10098
func: Callable,
10199
dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
102100
child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
103-
static_argnames: Optional[Union[str, Any]] = None,
101+
static_argnums: Union[int, Iterable[int], None] = None,
102+
static_argnames: Union[str, Iterable[str], None] = None,
104103
device: Optional[Any] = None,
105104
inline: bool = False,
106105
keep_unused: bool = False,
@@ -230,6 +229,7 @@ def jit(
230229
return JITTransform(target=func,
231230
dyn_vars=dyn_vars,
232231
child_objs=child_objs,
232+
static_argnums=static_argnums,
233233
static_argnames=static_argnames,
234234
device=device,
235235
inline=inline,

0 commit comments

Comments
 (0)