Skip to content

Commit 7f87aa8

Browse files
committed
add faster linear algebra routines, utils for least squares ftle
1 parent 5a9ebe0 commit 7f87aa8

File tree

1 file changed

+169
-47
lines changed

1 file changed

+169
-47
lines changed

src/numbacs/utils.py

Lines changed: 169 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88

99
@njit(inline="always")
10-
def gradF_stencil(F, i, j, dx, dy):
10+
def gradF_stencil_2D(F, i, j, dx, dy):
1111
"""
12-
Stencil for computing the gradient of F, a grid of 2D vectors, at i, j,
12+
Stencil for computing the gradient of F, a grid of 2D vectors, at (i, j),
1313
with spacing dx and dy. Not boundary safe.
1414
1515
Parameters
@@ -47,9 +47,9 @@ def gradF_stencil(F, i, j, dx, dy):
4747

4848

4949
@njit(inline="always")
50-
def gradF_aux_stencil(F_aux, i, j, h):
50+
def gradF_aux_stencil_2D(F_aux, i, j, h):
5151
"""
52-
Stencil for computing the gradient of F_aux, a grid of 2D vectors, at i, j,
52+
Stencil for computing the gradient of F_aux, a grid of 2D vectors, at (i, j),
5353
using the aux grid, with spacing h.
5454
5555
Parameters
@@ -85,9 +85,9 @@ def gradF_aux_stencil(F_aux, i, j, h):
8585

8686

8787
@njit(inline="always")
88-
def gradF_main_stencil(F_aux, i, j, dx, dy):
88+
def gradF_main_stencil_2D(F_aux, i, j, dx, dy):
8989
"""
90-
Stencil for computing the gradient of F_aux, a grid of 2D vectors, at i, j,
90+
Stencil for computing the gradient of F_aux, a grid of 2D vectors, at (i, j),
9191
using the main grid, with spacing dx, dy. Not boundary safe.
9292
9393
Parameters
@@ -125,9 +125,9 @@ def gradF_main_stencil(F_aux, i, j, dx, dy):
125125

126126

127127
@njit(inline="always")
128-
def gradUV_stencil(U, V, i, j, dx, dy):
128+
def gradUV_stencil_2D(U, V, i, j, dx, dy):
129129
"""
130-
Stencil for computing the gradient of velocity, defined by U, V, at i, j,
130+
Stencil for computing the gradient of velocity, defined by U, V, at (i, j),
131131
with spacing dx and dy. Not boundary safe.
132132
133133
Parameters
@@ -165,6 +165,100 @@ def gradUV_stencil(U, V, i, j, dx, dy):
165165
return dUdx, dUdy, dVdx, dVdy
166166

167167

168+
@njit(inline="always")
169+
def eigvalsh_max_2D(A):
170+
"""
171+
Computes the maximum eigenvalue for a Hermetian 2x2 array A.
172+
173+
Parameters
174+
----------
175+
A : np.ndarray, shape = (2, 2)
176+
Hermetian 2x2 array.
177+
178+
Returns
179+
-------
180+
float
181+
maximum eigenvalue of A.
182+
183+
"""
184+
185+
a, b, d = A[0, 0], A[0, 1], A[1, 1]
186+
trace = a + d
187+
discriminant = sqrt((a - d) ** 2 + 4 * (b**2))
188+
189+
return 0.5 * (trace + discriminant)
190+
191+
192+
@njit(inline="always")
193+
def inv_2D(A):
194+
"""
195+
Computes the inverse of a 2x2 array A.
196+
197+
Parameters
198+
----------
199+
A : np.ndarray, shape = (2, 2)
200+
2x2 array.
201+
202+
Returns
203+
-------
204+
np.ndarray, shape = (2, 2)
205+
inverse of A.
206+
207+
"""
208+
a, b, c, d = A[0, 0], A[0, 1], A[1, 0], A[1, 1]
209+
210+
det = a * d - b * c
211+
212+
if det != 0:
213+
return np.array([[d, -b], [-c, a]]) / det
214+
else:
215+
return np.zeros((2, 2), numba.float64)
216+
217+
218+
@njit(inline="always")
219+
def vec_dot_2D(v1, v2):
220+
"""
221+
Vector dot product for 2D vectors.
222+
223+
Parameters
224+
----------
225+
v1 : np.ndarray, shape=(2,)
226+
first vector.
227+
v2 : np.ndarray, shape=(2,)
228+
second vector.
229+
230+
Returns
231+
-------
232+
float
233+
dot product.
234+
235+
"""
236+
237+
return v1[0] * v2[0] + v1[1] * v2[1]
238+
239+
240+
@njit(inline="always")
241+
def vec_dot_3D(v1, v2):
242+
"""
243+
Vector dot product for 3D vectors.
244+
245+
Parameters
246+
----------
247+
v1 : np.ndarray, shape=(3,)
248+
first vector.
249+
v2 : np.ndarray, shape=(3,)
250+
second vector.
251+
252+
Returns
253+
-------
254+
float
255+
dot product.
256+
257+
"""
258+
259+
return v1[0] * v2[0] + v1[1] * v2[1] + v1[2] * v2[2]
260+
261+
168262
@njit
169263
def unravel_index(index, shape):
170264
"""
@@ -1172,78 +1266,106 @@ def cart_prod(vecs):
11721266
return prod
11731267

11741268

1175-
@njit(inline="always")
1176-
def eigvalsh_max_2D(A):
1269+
def scipy_dilate_mask(mask, corners=False):
11771270
"""
1178-
Computes the maximum eigenvalue for a Hermetian 2x2 array A.
1271+
A wrapper for scipy.ndimage.binary_dilation() for expanding
1272+
a mask to all grid points that have neighbors that are True.
11791273
11801274
Parameters
11811275
----------
1182-
A : np.ndarray, shape = (2, 2)
1183-
Hermetian 2x2 array.
1276+
mask : np.ndarray, shape=(nx, ny), dtype=bool
1277+
boolean array with True values corresponding to masked values.
1278+
corners : bool, optional
1279+
if True, only cardinal directions are used, if False, corners
1280+
are also used. The default is False.
11841281
11851282
Returns
11861283
-------
1187-
float
1188-
maximum eigenvalue of A.
1284+
np.ndarray, shape=(nx, ny), dtype=bool
1285+
dilated mask.
11891286
11901287
"""
1288+
if not corners:
1289+
structure = generate_binary_structure(2, 1)
1290+
else:
1291+
structure = generate_binary_structure(2, 2)
11911292

1192-
a, b, d = A[0, 0], A[0, 1], A[1, 1]
1193-
trace = a + d
1194-
discriminant = sqrt((a - d) ** 2 + 4 * (b**2))
1195-
1196-
return 0.5 * (trace + discriminant)
1293+
return binary_dilation(mask, structure=structure)
11971294

11981295

1199-
@njit(inline="always")
1200-
def inv_2D(A):
1296+
def lonlat2xyz(Lon, Lat, r, deg2rad=False, return_array=False):
12011297
"""
1202-
Computes the inverse of a 2x2 array A.
1298+
Convert lon, lat positions to x, y, z.
12031299
12041300
Parameters
12051301
----------
1206-
A : np.ndarray, shape = (2, 2)
1207-
2x2 array.
1302+
Lon : np.ndarray, shape=(nx, ny)
1303+
meshgrid of longitude.
1304+
Lat : np.ndarray, shape=(nx, ny)
1305+
meshgrid of latitude.
1306+
r : float
1307+
radius.
1308+
deg2rad : bool, optional
1309+
flag to convert from degree to radians. Lon, Lat must either
1310+
already be in radians, or this flag must be set to True.
1311+
The default is False.
1312+
return_array : bool, optional
1313+
flag to return stacked array instead of tuple. The default is False.
12081314
12091315
Returns
12101316
-------
1211-
np.ndarray, shape = (2, 2)
1212-
inverse of A.
1317+
tuple or np.ndarray
1318+
either tuple or stacked array containing meshgrid of X, Y, Z position.
12131319
12141320
"""
1215-
a, b, c, d = A[0, 0], A[0, 1], A[1, 0], A[1, 1]
12161321

1217-
det = a * d - b * c
1322+
if deg2rad:
1323+
Lon = np.deg2rad(Lon)
1324+
Lat = np.deg2rad(Lat)
12181325

1219-
if det != 0:
1220-
return np.array([[d, -b], [-c, a]]) / det
1326+
Xp = r * np.cos(Lat) * np.cos(Lon)
1327+
Yp = r * np.cos(Lat) * np.sin(Lon)
1328+
Zp = r * np.sin(Lat)
1329+
1330+
if return_array:
1331+
return np.stack((Xp, Yp, Zp), axis=-1)
12211332
else:
1222-
return np.zeros((2, 2), numba.float64)
1333+
return Xp, Yp, Zp
12231334

12241335

1225-
def scipy_dilate_mask(mask, corners=False):
1336+
def local_basis_S2(Lon, Lat, deg2rad=False):
12261337
"""
1227-
A wrapper for scipy.ndimage.binary_dilation() for expanding
1228-
a mask to all grid points that have neighbors that are True.
1338+
Create a local basis on the surface of the sphere (S2) in x, y, z coords.
12291339
12301340
Parameters
12311341
----------
1232-
mask : np.ndarray, shape=(nx, ny), dtype=bool
1233-
boolean array with True values corresponding to masked values.
1234-
corners : bool, optional
1235-
if True, only cardinal directions are used, if False, corners
1236-
are also used. The default is False.
1342+
Lon : np.ndarray, shape=(nx, ny)
1343+
meshgrid of longitude.
1344+
Lat : np.ndarray, shape=(nx, ny)
1345+
meshgrid of latitude.
1346+
deg2rad : bool, optional
1347+
flag to convert from degree to radians. Lon, Lat must either
1348+
already be in radians, or this flag must be set to True.
1349+
The default is False.
12371350
12381351
Returns
12391352
-------
1240-
np.ndarray, shape=(nx, ny), dtype=bool
1241-
dilated mask.
1353+
e1 : np.ndarray, shape=(nx, ny, 2)
1354+
local basis vector in the "east" direction.
1355+
e2 : np.ndarray, shape=(nx, ny, 2)
1356+
local basis vector in the "north" direction.
12421357
12431358
"""
1244-
if not corners:
1245-
structure = generate_binary_structure(2, 1)
1246-
else:
1247-
structure = generate_binary_structure(2, 2)
1359+
nx, ny = Lon.shape
1360+
if deg2rad:
1361+
Lon = np.deg2rad(Lon)
1362+
Lat = np.deg2rad(Lat)
12481363

1249-
return binary_dilation(mask, structure=structure)
1364+
sinLon = np.sin(Lon)
1365+
sinLat = np.sin(Lat)
1366+
cosLon = np.cos(Lon)
1367+
cosLat = np.cos(Lat)
1368+
e1 = np.stack((-sinLon, cosLon, np.zeros((nx, ny), np.float64)), axis=-1)
1369+
e2 = np.stack((-sinLat * cosLon, -sinLat * sinLon, cosLat), axis=-1)
1370+
1371+
return e1, e2

0 commit comments

Comments
 (0)