11import jax .numpy as jnp
2- from jax import jit
3- from functools import partial
4-
5-
62
73def linear_2D (t , x , params ):
8- '''
9- :param x: 2D vector
10- type: jax array
11- shape:(2,)
12-
13- :param t: Unused
14-
15- :param params: Unused
16-
17- :return: 2D vector: [
18- -0.1 * x[0] + 2.0 * x[1],
19- -2.0 * x[0] - 0.1 * x[1]
20- ]
21- type: jax array
22- shape:(2,)
23-
24- ------------------------------------------
25- * suggested init value-
26- x0 = jnp.array([3, -1.5] )
27- '''
4+ """
5+ * suggested init value - x0 = jnp.array([3, -1.5])
6+
7+ Args:
8+ param x: 2D vector
9+ type: jax array
10+ shape:(2,)
11+
12+ param t: Unused
13+
14+ param params: Unused
15+
16+ Returns:
17+ 2D vector: [
18+ -0.1 * x[0] + 2.0 * x[1],
19+ -2.0 * x[0] - 0.1 * x[1]
20+ ]
21+ type: jax array
22+ shape:(2, )
23+ """
2824 coeff = jnp .array ([[- 0.1 , 2 ],
2925 [- 2 , - 0.1 ]]).T
3026 dfx_ = jnp .matmul (x , coeff )
3127
3228 return dfx_
3329
34-
35-
3630def cubic_2D (t , x , params ):
37- '''
38- :param x: 2D vector
39- type: jax array
40- shape:(2,)
41-
42- :param t: Unused
43-
44- :param params: Unused
45-
46- :return: 2D vector: [
47- -0.1 * x[0] ** 3 + 2.0 * x[1] ** 3,
48- -2.0 * x[0] ** 3 - 0.1 * x[1] ** 3,
49- ]
50- type: jax array
51- shape:(2,)
52-
53- ------------------------------------------
54- * suggested init value-
55- x0 = jnp.array([2., 0.] )
56- '''
31+ """
32+ suggested init value - x0 = jnp.array([2., 0.])
33+
34+ Args:
35+ param x: 2D vector
36+ type: jax array
37+ shape: (2,)
38+
39+ param t: Unused
40+
41+ param params: Unused
42+
43+ Returns:
44+ 2D vector: [
45+ -0.1 * x[0] ** 3 + 2.0 * x[1] ** 3,
46+ -2.0 * x[0] ** 3 - 0.1 * x[1] ** 3,
47+ ]
48+ type: jax array
49+ shape:(2, )
50+ """
5751 coeff = jnp .array ([[- 0.1 , 2 ],
5852 [- 2 , - 0.1 ]]).T
5953 dfx_ = jnp .matmul (x ** 3 , coeff )
6054 return dfx_
6155
62-
63-
6456def lorenz (t , x , params ):
65- '''
66- :param x: 3D vector
67- type: jax array
68- shape:(3,)
69-
70- :param t: Unused
71-
72- :param params: Unused
73-
74- :return: 3D vector: [
75- 10 * (x[1] - x[0]),
76- x[0] * (28 - x[2]) - x[1],
77- x[0] * x[1] - 8 / 3 * x[2],
78- ]
79- type: jax array
80- shape:(3,)
81-
82- ------------------------------------------
83- * suggested init value-
84- x0 = jnp.array([-8, 7, 27] )
85- '''
57+ """
58+ suggested init value - x0 = jnp.array([-8, 7, 27])
59+
60+ Args:
61+ param x: 3D vector
62+ type: jax array
63+ shape: (3,)
64+
65+ param t: Unused
66+
67+ param params: Unused
68+
69+ Returns:
70+ 3D vector: [
71+ 10 * (x[1] - x[0]),
72+ x[0] * (28 - x[2]) - x[1],
73+ x[0] * x[1] - 8 / 3 * x[2],
74+ ]
75+ type: jax array
76+ shape:(3, )
77+ """
8678 x_ = x [..., 0 ]
8779 y_ = x [..., 1 ]
8880 z_ = x [..., 2 ]
8981
9082 dx = 10 * y_ - 10 * x_
9183 dy = 28 * x_ - x_ * z_ - y_
9284 dz = x_ * y_ - 8 / 3 * z_
93-
9485 return jnp .stack ([dx , dy , dz ], axis = - 1 )
9586
9687
97-
9888def linear_3D (t , x , params ):
99- '''
100- :param x: 3D vector
101- type: jax array
102- shape:(3,)
103-
104- :param t: Unused
105-
106- :param params: Unused
107-
108- :return: 3D vector: [
109- -0.1 * x[0] + 2 * x[1],
110- -2 * x[0] - 0.1 * x[1],
111- -0.3 * x[2]
112- ]
113- type: jax array
114- shape:(3,)
115- ------------------------------------------
116- * suggested init value-
117- x0 = jnp.array([1, 1., -1])
118- '''
89+ """
90+ suggested init value - x0 = jnp.array([1, 1., -1])
91+
92+ Args:
93+ param x: 3D vector
94+ type: jax array
95+ shape: (3,)
96+
97+ param t: Unused
98+
99+ param params: Unused
100+
101+ Returns:
102+ 3D vector: [
103+ -0.1 * x[0] + 2 * x[1],
104+ -2 * x[0] - 0.1 * x[1],
105+ -0.3 * x[2]
106+ ]
107+ type: jax array
108+ shape:(3,)
109+ """
119110 x_ = x [..., 0 ]
120111 y_ = x [..., 1 ]
121112 z_ = x [..., 2 ]
@@ -130,27 +121,28 @@ def linear_3D(t, x, params):
130121
131122
132123def oscillator (t , x , params , mu1 = 0.05 , mu2 = - 0.01 , omega = 3.0 , alpha = - 2.0 , beta = - 5.0 , sigma = 1.1 ):
133- '''
134- :param x: 3D vector
135- type: jax array
136- shape:(3,)
137-
138- :param t: Unused
139-
140- :param params: Unused
141-
142- :return: 3D vector: [
143- mu1 * x[0] + sigma * x[0] * x[1],
144- mu2 * x[1] + (omega + alpha * x[1] + beta * x[2]) * x[2] - sigma * x[0] ** 2,
145- mu2 * x[2] - (omega + alpha * x[1] + beta * x[2]) * x[1],
146- ]
147-
148- type: jax array
149- shape:(3,)
150- ------------------------------------------
151- * suggested init value-
152- x0 = jnp.array([0.5, 0.05, 0.1])
153- '''
124+ """
125+ suggested init value - x0 = jnp.array([0.5, 0.05, 0.1])
126+
127+ Args:
128+ param x: 3D vector
129+ type: jax array
130+ shape: (3,)
131+
132+ param t: Unused
133+
134+ param params: Unused
135+
136+ Returns:
137+ 3D vector: [
138+ mu1 * x[0] + sigma * x[0] * x[1],
139+ mu2 * x[1] + (omega + alpha * x[1] + beta * x[2]) * x[2] - sigma * x[0] ** 2,
140+ mu2 * x[2] - (omega + alpha * x[1] + beta * x[2]) * x[1],
141+ ]
142+
143+ type: jax array
144+ shape:(3,)
145+ """
154146 x_ = x [..., 0 ]
155147 y_ = x [..., 1 ]
156148 z_ = x [..., 2 ]
@@ -162,13 +154,7 @@ def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5
162154 return jnp .stack ([dx , dy , dz ], axis = - 1 )
163155
164156
165-
166-
167-
168-
169-
170-
171-
157+ ## some testing/driver code to check the ODEs themselves
172158if __name__ == "__main__" :
173159 import matplotlib .pyplot as plt
174160 from ngclearn .utils .diffeq .ode_utils import solve_ode
0 commit comments