Skip to content

Commit f88b005

Browse files
authored
Merge pull request #94 from vadmbertr/diff-op
Simplify mask handling in diff operators
2 parents c4d3eb6 + 20e5843 commit f88b005

File tree

3 files changed

+53
-180
lines changed

3 files changed

+53
-180
lines changed

jaxparrow/geostrophy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def geostrophy(
6161
coriolis_factor_t = geometry.compute_coriolis_factor(lat_t)
6262

6363
# Handle spurious and masked data
64-
ssh_t = sanitize.sanitize_data(ssh_t, 0, is_land)
64+
ssh_t = sanitize.sanitize_data(ssh_t, jnp.nan, is_land)
6565

6666
u_geos_u, v_geos_v = _geostrophy(ssh_t, dx_t, dy_t, coriolis_factor_t, is_land)
6767

jaxparrow/utils/operators.py

Lines changed: 52 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import jax.numpy as jnp
55
from jaxtyping import Array, Float
66

7-
from .sanitize import handle_land_boundary
8-
97

108
def interpolation(
119
field: Float[Array, "lat lon"],
@@ -40,44 +38,34 @@ def interpolation(
4038
field : Float[Array, "lat lon"]
4139
Interpolated field
4240
"""
43-
def do_interpolate(field_b, field_f, mask_b, mask_f, pad_left):
44-
field_b, field_f = handle_land_boundary(field_b, field_f, mask_b, mask_f, pad_left)
45-
return 0.5 * (field_b + field_f)
46-
47-
def axis0(pad_left):
48-
field_b, field_f = field[:-1, :], field[1:, :]
49-
mask_b, mask_f = mask[:-1, :], mask[1:, :]
50-
midpoint_values = do_interpolate(field_b, field_f, mask_b, mask_f, pad_left)
51-
52-
arr = lax.cond(
53-
pad_left,
54-
lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"),
55-
lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge")
56-
)
57-
58-
return arr
59-
60-
def axis1(pad_left):
61-
field_b, field_f = field[:, :-1], field[:, 1:]
62-
mask_b, mask_f = mask[:, :-1], mask[:, 1:]
63-
midpoint_values = do_interpolate(field_b, field_f, mask_b, mask_f, pad_left)
64-
65-
arr = lax.cond(
66-
pad_left,
67-
lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"),
68-
lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge")
69-
)
70-
71-
return arr
72-
73-
field = lax.cond(
74-
axis == 0,
75-
lambda pad_left: axis0(pad_left),
76-
lambda pad_left: axis1(pad_left),
77-
padding == "left"
41+
f = jnp.moveaxis(field, axis, -1)
42+
43+
mid = (f[:, :-1] + f[:, 1:]) * 0.5
44+
45+
# handle mask: extrapolate at land boundaries (up to 1 cell)
46+
mid = jnp.where(
47+
jnp.isnan(mid),
48+
f[:, :-1],
49+
mid
50+
)
51+
mid = jnp.where(
52+
jnp.isnan(mid),
53+
f[:, 1:],
54+
mid
55+
)
56+
57+
# extrapolate at the domain boundary
58+
mid = lax.cond(
59+
padding == "left",
60+
lambda: jnp.concatenate([f[:, :1], mid], axis=-1),
61+
lambda: jnp.concatenate([mid, f[:, -1:]], axis=-1)
7862
)
7963

80-
return field
64+
mid = jnp.moveaxis(mid, -1, axis)
65+
66+
mid = jnp.where(mask, jnp.nan, mid)
67+
68+
return mid
8169

8270

8371
def derivative(
@@ -116,41 +104,31 @@ def derivative(
116104
field : Float[Array, "lat lon"]
117105
Interpolated field
118106
"""
119-
def do_differentiate(field_b, field_f, mask_b, mask_f, pad_left):
120-
field_b, field_f = handle_land_boundary(field_b, field_f, mask_b, mask_f, pad_left)
121-
return field_f - field_b
122-
123-
def axis0(pad_left):
124-
field_b, field_f = field[:-1, :], field[1:, :]
125-
mask_b, mask_f = mask[:-1, :], mask[1:, :]
126-
midpoint_values = do_differentiate(field_b, field_f, mask_b, mask_f, pad_left)
127-
128-
arr = lax.cond(
129-
pad_left,
130-
lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"),
131-
lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge")
132-
)
133-
134-
return arr
135-
136-
def axis1(pad_left):
137-
field_b, field_f = field[:, :-1], field[:, 1:]
138-
mask_b, mask_f = mask[:, :-1], mask[:, 1:]
139-
midpoint_values = do_differentiate(field_b, field_f, mask_b, mask_f, pad_left)
140-
141-
arr = lax.cond(
142-
pad_left,
143-
lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"),
144-
lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge")
145-
)
146-
147-
return arr
148-
149-
field = lax.cond(
150-
axis == 0,
151-
lambda pad_left: axis0(pad_left),
152-
lambda pad_left: axis1(pad_left),
153-
padding == "left"
107+
f = jnp.moveaxis(field, axis, -1)
108+
109+
mid = jnp.diff(f, axis=-1)
110+
111+
# handle mask: extrapolate at land boundaries (up to 1 cell)
112+
mid = jnp.where(
113+
jnp.isnan(mid),
114+
jnp.pad(mid[:, 1:], pad_width=((0, 0), (0, 1)), mode="edge"),
115+
mid
154116
)
117+
mid = jnp.where(
118+
jnp.isnan(mid),
119+
jnp.pad(mid[:, :-1], pad_width=((0, 0), (1, 0)), mode="edge"),
120+
mid
121+
)
122+
123+
# extrapolate at the domain boundary
124+
mid = lax.cond(
125+
padding == "left",
126+
lambda: jnp.pad(mid, pad_width=((0, 0), (1, 0)), mode="edge"),
127+
lambda: jnp.pad(mid, pad_width=((0, 0), (0, 1)), mode="edge" )
128+
)
129+
130+
mid = jnp.moveaxis(mid, -1, axis)
131+
132+
mid = jnp.where(mask, jnp.nan, mid)
155133

156-
return field / dxy
134+
return mid / dxy

jaxparrow/utils/sanitize.py

Lines changed: 0 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from jax import lax
22
import jax.numpy as jnp
33
from jaxtyping import Array, Float
4-
import numpy as np
5-
from scipy import interpolate
64

75

86
def sanitize_data(
@@ -57,106 +55,3 @@ def init_land_mask(
5755
if mask is None:
5856
mask = ~jnp.isfinite(field)
5957
return mask
60-
61-
62-
def handle_land_boundary(
63-
field1: Float[Array, "lat lon"],
64-
field2: Float[Array, "lat lon"],
65-
mask1: Float[Array, "lat lon"],
66-
mask2: Float[Array, "lat lon"],
67-
pad_left: bool
68-
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"]]:
69-
"""
70-
Replaces the masked values of ``field1`` (``field2``) with values of ``field2`` (``field1``), element-wise.
71-
72-
It allows computing more coherent values when applying grid operators.
73-
In such cases, ``field1`` and ``field2`` are left and right shifted versions of a field (along one of the axes).
74-
75-
Parameters
76-
----------
77-
field1 : Float[Array, "lat lon"]
78-
A field
79-
field2 : Float[Array, "lat lon"]
80-
Another field
81-
mask1 : Float[Array, "lat lon"]
82-
A mask defining the marine area of ``field1`` spatial domain; `1` or `True` stands for masked (i.e. land)
83-
mask2 : Float[Array, "lat lon"]
84-
A mask defining the marine area of ``field2`` spatial domain; `1` or `True` stands for masked (i.e. land)
85-
pad_left : bool
86-
If `True`, apply padding in the `left` direction (i.e. `West` or `South`) ;
87-
if `False`, apply padding in the `right` direction (i.e. `East` or `North`).
88-
89-
Returns
90-
-------
91-
field1 : Float[Array, "lat lon"]
92-
A field whose masked values have been replaced with the ones from ``field2``
93-
field2 : Float[Array, "lat lon"]
94-
A field whose masked values have been replaced with the ones from ``field1``
95-
"""
96-
field1, field2 = lax.cond(
97-
pad_left,
98-
lambda: (jnp.where(mask1, field2, field1), field2),
99-
lambda: (field1, jnp.where(mask2, field1, field2))
100-
)
101-
return field1, field2
102-
103-
104-
def sanitize_grid_np(
105-
lat: Float[Array, "lat lon"],
106-
lon: Float[Array, "lat lon"],
107-
mask: Float[Array, "lat lon"] = None
108-
) -> [Float[Array, "lat lon"], Float[Array, "lat lon"]]:
109-
"""
110-
Sanitizes (unstructured) grids by interpolating and extrapolating `nan` or masked values to avoid spurious
111-
(`0`, `nan`, `inf`) spatial steps and Coriolis factors.
112-
113-
Helper function written using pure ``numpy`` and ``scipy``, and as such not used internally,
114-
because it is incompatible with ``jax.vmap`` and the like.
115-
Should be used before calling ``jaxparrow.geostrophy`` or ``jaxparrow.cyclogeostrophy``
116-
in case of suspicious latitudes or longitudes T grids.
117-
118-
Caution: because it uses ``scipy.interpolate.RBFInterpolator``,
119-
its memory usage grows quadratically with the number of grid points.
120-
121-
Parameters
122-
----------
123-
lat : Float[Array, "lat lon"]
124-
Grid latitudes
125-
lon : Float[Array, "lat lon"]
126-
Grid longitudes
127-
mask : Float[Array, "lat lon"], optional
128-
Mask to apply, `1` or `True` for masked, defaults to `None`
129-
130-
Returns
131-
-------
132-
lat : Float[Array, "lat lon"]
133-
Grid latitudes
134-
lon : Float[Array, "lat lon"]
135-
Grid longitudes
136-
"""
137-
def fill_nan(arr: Float[Array, "lat lon"]) -> Float[Array, "lat lon"]:
138-
x = np.arange(0, arr.shape[1])
139-
y = np.arange(0, arr.shape[0])
140-
# mask invalid values
141-
arr = np.ma.masked_invalid(arr)
142-
xx, yy = np.meshgrid(x, y)
143-
# get only the valid values
144-
valid_x = xx[~arr.mask]
145-
valid_y = yy[~arr.mask]
146-
valid_arr = arr[~arr.mask]
147-
rbf = interpolate.RBFInterpolator(np.array([valid_x, valid_y]).T, valid_arr)
148-
# get the invalid ones
149-
invalid_x = xx[arr.mask]
150-
invalid_y = yy[arr.mask]
151-
invalid_arr = rbf(np.array([invalid_x, invalid_y]).T)
152-
# fill
153-
arr[arr.mask] = invalid_arr
154-
return arr.data
155-
156-
# make sure nan are used behind masked pixels (and not 0)
157-
lat = sanitize_data(lat, jnp.nan, mask)
158-
lon = sanitize_data(lon, jnp.nan, mask)
159-
# fill nan using RBF interpolation
160-
lat = fill_nan(lat)
161-
lon = fill_nan(lon)
162-
return lat, lon

0 commit comments

Comments
 (0)