Skip to content

Commit d5def75

Browse files
authored
Update odes.py (#87)
removed @partial(jit, static_argnums=(0,))
1 parent cf53968 commit d5def75

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

ngclearn/utils/diffeq/odes.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from functools import partial
44

55

6-
@partial(jit, static_argnums=(0,))
6+
77
def linear_2D(t, x, params):
88
'''
99
:param x: 2D vector
@@ -32,7 +32,7 @@ def linear_2D(t, x, params):
3232
return dfx_
3333

3434

35-
@partial(jit, static_argnums=(0,))
35+
3636
def cubic_2D(t, x, params):
3737
'''
3838
:param x: 2D vector
@@ -60,7 +60,7 @@ def cubic_2D(t, x, params):
6060
return dfx_
6161

6262

63-
@partial(jit, static_argnums=(0,))
63+
6464
def lorenz(t, x, params):
6565
'''
6666
:param x: 3D vector
@@ -94,7 +94,7 @@ def lorenz(t, x, params):
9494
return jnp.stack([dx, dy, dz], axis=-1)
9595

9696

97-
@partial(jit, static_argnums=(0,))
97+
9898
def linear_3D(t, x, params):
9999
'''
100100
:param x: 3D vector
@@ -127,7 +127,8 @@ def linear_3D(t, x, params):
127127
return jnp.stack([dx, dy, dz], axis=-1)
128128

129129

130-
@partial(jit, static_argnums=(0,))
130+
131+
131132
def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5.0, sigma=1.1):
132133
'''
133134
:param x: 3D vector
@@ -237,3 +238,10 @@ def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5
237238
ax.set_zlabel('z', fontsize=20)
238239
plt.grid(True)
239240
plt.show()
241+
242+
243+
244+
245+
246+
247+

0 commit comments

Comments
 (0)