Skip to content

Commit 9de3c98

Browse files
author
Alexander Ororbia
committed
cleaned up doc-strings in odes.py to comply w/ ngc-learn format
1 parent 1d15f1f commit 9de3c98

File tree

1 file changed

+105
-119
lines changed

1 file changed

+105
-119
lines changed

ngclearn/utils/diffeq/odes.py

Lines changed: 105 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,121 +1,112 @@
11
import jax.numpy as jnp
2-
from jax import jit
3-
from functools import partial
4-
5-
62

73
def linear_2D(t, x, params):
8-
'''
9-
:param x: 2D vector
10-
type: jax array
11-
shape:(2,)
12-
13-
:param t: Unused
14-
15-
:param params: Unused
16-
17-
:return: 2D vector: [
18-
-0.1 * x[0] + 2.0 * x[1],
19-
-2.0 * x[0] - 0.1 * x[1]
20-
]
21-
type: jax array
22-
shape:(2,)
23-
24-
------------------------------------------
25-
* suggested init value-
26-
x0 = jnp.array([3, -1.5])
27-
'''
4+
"""
5+
* suggested init value - x0 = jnp.array([3, -1.5])
6+
7+
Args:
8+
param x: 2D vector
9+
type: jax array
10+
shape:(2,)
11+
12+
param t: Unused
13+
14+
param params: Unused
15+
16+
Returns:
17+
2D vector: [
18+
-0.1 * x[0] + 2.0 * x[1],
19+
-2.0 * x[0] - 0.1 * x[1]
20+
]
21+
type: jax array
22+
shape:(2,)
23+
"""
2824
coeff = jnp.array([[-0.1, 2],
2925
[-2, -0.1]]).T
3026
dfx_ = jnp.matmul(x, coeff)
3127

3228
return dfx_
3329

34-
35-
3630
def cubic_2D(t, x, params):
37-
'''
38-
:param x: 2D vector
39-
type: jax array
40-
shape:(2,)
41-
42-
:param t: Unused
43-
44-
:param params: Unused
45-
46-
:return: 2D vector: [
47-
-0.1 * x[0] ** 3 + 2.0 * x[1] ** 3,
48-
-2.0 * x[0] ** 3 - 0.1 * x[1] ** 3,
49-
]
50-
type: jax array
51-
shape:(2,)
52-
53-
------------------------------------------
54-
* suggested init value-
55-
x0 = jnp.array([2., 0.])
56-
'''
31+
"""
32+
suggested init value - x0 = jnp.array([2., 0.])
33+
34+
Args:
35+
param x: 2D vector
36+
type: jax array
37+
shape: (2,)
38+
39+
param t: Unused
40+
41+
param params: Unused
42+
43+
Returns:
44+
2D vector: [
45+
-0.1 * x[0] ** 3 + 2.0 * x[1] ** 3,
46+
-2.0 * x[0] ** 3 - 0.1 * x[1] ** 3,
47+
]
48+
type: jax array
49+
shape:(2,)
50+
"""
5751
coeff = jnp.array([[-0.1, 2],
5852
[-2, -0.1]]).T
5953
dfx_ = jnp.matmul(x**3, coeff)
6054
return dfx_
6155

62-
63-
6456
def lorenz(t, x, params):
65-
'''
66-
:param x: 3D vector
67-
type: jax array
68-
shape:(3,)
69-
70-
:param t: Unused
71-
72-
:param params: Unused
73-
74-
:return: 3D vector: [
75-
10 * (x[1] - x[0]),
76-
x[0] * (28 - x[2]) - x[1],
77-
x[0] * x[1] - 8 / 3 * x[2],
78-
]
79-
type: jax array
80-
shape:(3,)
81-
82-
------------------------------------------
83-
* suggested init value-
84-
x0 = jnp.array([-8, 7, 27])
85-
'''
57+
"""
58+
suggested init value - x0 = jnp.array([-8, 7, 27])
59+
60+
Args:
61+
param x: 3D vector
62+
type: jax array
63+
shape: (3,)
64+
65+
param t: Unused
66+
67+
param params: Unused
68+
69+
Returns:
70+
3D vector: [
71+
10 * (x[1] - x[0]),
72+
x[0] * (28 - x[2]) - x[1],
73+
x[0] * x[1] - 8 / 3 * x[2],
74+
]
75+
type: jax array
76+
shape:(3,)
77+
"""
8678
x_ = x[..., 0]
8779
y_ = x[..., 1]
8880
z_ = x[..., 2]
8981

9082
dx = 10 * y_ - 10 * x_
9183
dy = 28 * x_ - x_ * z_ - y_
9284
dz = x_ * y_ - 8 / 3 * z_
93-
9485
return jnp.stack([dx, dy, dz], axis=-1)
9586

9687

97-
9888
def linear_3D(t, x, params):
99-
'''
100-
:param x: 3D vector
101-
type: jax array
102-
shape:(3,)
103-
104-
:param t: Unused
105-
106-
:param params: Unused
107-
108-
:return: 3D vector: [
109-
-0.1 * x[0] + 2 * x[1],
110-
-2 * x[0] - 0.1 * x[1],
111-
-0.3 * x[2]
112-
]
113-
type: jax array
114-
shape:(3,)
115-
------------------------------------------
116-
* suggested init value-
117-
x0 = jnp.array([1, 1., -1])
118-
'''
89+
"""
90+
suggested init value - x0 = jnp.array([1, 1., -1])
91+
92+
Args:
93+
param x: 3D vector
94+
type: jax array
95+
shape: (3,)
96+
97+
param t: Unused
98+
99+
param params: Unused
100+
101+
Returns:
102+
3D vector: [
103+
-0.1 * x[0] + 2 * x[1],
104+
-2 * x[0] - 0.1 * x[1],
105+
-0.3 * x[2]
106+
]
107+
type: jax array
108+
shape:(3,)
109+
"""
119110
x_ = x[..., 0]
120111
y_ = x[..., 1]
121112
z_ = x[..., 2]
@@ -130,27 +121,28 @@ def linear_3D(t, x, params):
130121

131122

132123
def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5.0, sigma=1.1):
133-
'''
134-
:param x: 3D vector
135-
type: jax array
136-
shape:(3,)
137-
138-
:param t: Unused
139-
140-
:param params: Unused
141-
142-
:return: 3D vector: [
143-
mu1 * x[0] + sigma * x[0] * x[1],
144-
mu2 * x[1] + (omega + alpha * x[1] + beta * x[2]) * x[2] - sigma * x[0] ** 2,
145-
mu2 * x[2] - (omega + alpha * x[1] + beta * x[2]) * x[1],
146-
]
147-
148-
type: jax array
149-
shape:(3,)
150-
------------------------------------------
151-
* suggested init value-
152-
x0 = jnp.array([0.5, 0.05, 0.1])
153-
'''
124+
"""
125+
suggested init value - x0 = jnp.array([0.5, 0.05, 0.1])
126+
127+
Args:
128+
param x: 3D vector
129+
type: jax array
130+
shape: (3,)
131+
132+
param t: Unused
133+
134+
param params: Unused
135+
136+
Returns:
137+
3D vector: [
138+
mu1 * x[0] + sigma * x[0] * x[1],
139+
mu2 * x[1] + (omega + alpha * x[1] + beta * x[2]) * x[2] - sigma * x[0] ** 2,
140+
mu2 * x[2] - (omega + alpha * x[1] + beta * x[2]) * x[1],
141+
]
142+
143+
type: jax array
144+
shape:(3,)
145+
"""
154146
x_ = x[..., 0]
155147
y_ = x[..., 1]
156148
z_ = x[..., 2]
@@ -162,13 +154,7 @@ def oscillator(t, x, params, mu1=0.05, mu2=-0.01, omega=3.0, alpha=-2.0, beta=-5
162154
return jnp.stack([dx, dy, dz], axis=-1)
163155

164156

165-
166-
167-
168-
169-
170-
171-
157+
## some testing/driver code to check the ODEs themselves
172158
if __name__ == "__main__":
173159
import matplotlib.pyplot as plt
174160
from ngclearn.utils.diffeq.ode_utils import solve_ode

0 commit comments

Comments
 (0)