|
3 | 3 | from functools import partial |
4 | 4 |
|
5 | 5 |
|
6 | | -@partial(jit, static_argnums=(0,)) |
| 6 | + |
7 | 7 | def linear_2D(t, x, params): |
8 | 8 | ''' |
9 | 9 | :param x: 2D vector |
@@ -32,7 +32,7 @@ def linear_2D(t, x, params): |
32 | 32 | return dfx_ |
33 | 33 |
|
34 | 34 |
|
35 | | -@partial(jit, static_argnums=(0,)) |
| 35 | + |
36 | 36 | def cubic_2D(t, x, params): |
37 | 37 | ''' |
38 | 38 | :param x: 2D vector |
@@ -60,7 +60,7 @@ def cubic_2D(t, x, params): |
60 | 60 | return dfx_ |
61 | 61 |
|
62 | 62 |
|
63 | | -@partial(jit, static_argnums=(0,)) |
| 63 | + |
64 | 64 | def lorenz(t, x, params): |
65 | 65 | ''' |
66 | 66 | :param x: 3D vector |
@@ -94,7 +94,7 @@ def lorenz(t, x, params): |
94 | 94 | return jnp.stack([dx, dy, dz], axis=-1) |
95 | 95 |
|
96 | 96 |
|
97 | | -@partial(jit, static_argnums=(0,)) |
| 97 | + |
98 | 98 | def linear_3D(t, x, params): |
99 | 99 | ''' |
100 | 100 | :param x: 3D vector |
@@ -127,7 +127,8 @@ def linear_3D(t, x, params): |
127 | 127 | return jnp.stack([dx, dy, dz], axis=-1) |
128 | 128 |
|
129 | 129 |
|
130 | | -@partial(jit, static_argnums=(0,)) |
| 130 | + |
| 131 | + |
131 | 132 | def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5.0, sigma=1.1): |
132 | 133 | ''' |
133 | 134 | :param x: 3D vector |
@@ -237,3 +238,10 @@ def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5 |
237 | 238 | ax.set_zlabel('z', fontsize=20) |
238 | 239 | plt.grid(True) |
239 | 240 | plt.show() |
| 241 | + |
| 242 | + |
| 243 | + |
| 244 | + |
| 245 | + |
| 246 | + |
| 247 | + |
0 commit comments