|
4 | 4 | import jax.numpy as jnp |
5 | 5 | from jaxtyping import Array, Float |
6 | 6 |
|
7 | | -from .sanitize import handle_land_boundary |
8 | | - |
9 | 7 |
|
10 | 8 | def interpolation( |
11 | 9 | field: Float[Array, "lat lon"], |
@@ -40,44 +38,34 @@ def interpolation( |
40 | 38 | field : Float[Array, "lat lon"] |
41 | 39 | Interpolated field |
42 | 40 | """ |
43 | | - def do_interpolate(field_b, field_f, mask_b, mask_f, pad_left): |
44 | | - field_b, field_f = handle_land_boundary(field_b, field_f, mask_b, mask_f, pad_left) |
45 | | - return 0.5 * (field_b + field_f) |
46 | | - |
47 | | - def axis0(pad_left): |
48 | | - field_b, field_f = field[:-1, :], field[1:, :] |
49 | | - mask_b, mask_f = mask[:-1, :], mask[1:, :] |
50 | | - midpoint_values = do_interpolate(field_b, field_f, mask_b, mask_f, pad_left) |
51 | | - |
52 | | - arr = lax.cond( |
53 | | - pad_left, |
54 | | - lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"), |
55 | | - lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge") |
56 | | - ) |
57 | | - |
58 | | - return arr |
59 | | - |
60 | | - def axis1(pad_left): |
61 | | - field_b, field_f = field[:, :-1], field[:, 1:] |
62 | | - mask_b, mask_f = mask[:, :-1], mask[:, 1:] |
63 | | - midpoint_values = do_interpolate(field_b, field_f, mask_b, mask_f, pad_left) |
64 | | - |
65 | | - arr = lax.cond( |
66 | | - pad_left, |
67 | | - lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"), |
68 | | - lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge") |
69 | | - ) |
70 | | - |
71 | | - return arr |
72 | | - |
73 | | - field = lax.cond( |
74 | | - axis == 0, |
75 | | - lambda pad_left: axis0(pad_left), |
76 | | - lambda pad_left: axis1(pad_left), |
77 | | - padding == "left" |
| 41 | + f = jnp.moveaxis(field, axis, -1) |
| 42 | + |
| 43 | + mid = (f[:, :-1] + f[:, 1:]) * 0.5 |
| 44 | + |
| 45 | + # handle mask: extrapolate at land boundaries (up to 1 cell) |
| 46 | + mid = jnp.where( |
| 47 | + jnp.isnan(mid), |
| 48 | + f[:, :-1], |
| 49 | + mid |
| 50 | + ) |
| 51 | + mid = jnp.where( |
| 52 | + jnp.isnan(mid), |
| 53 | + f[:, 1:], |
| 54 | + mid |
| 55 | + ) |
| 56 | + |
| 57 | + # extrapolate at the domain boundary |
| 58 | + mid = lax.cond( |
| 59 | + padding == "left", |
| 60 | + lambda: jnp.concatenate([f[:, :1], mid], axis=-1), |
| 61 | + lambda: jnp.concatenate([mid, f[:, -1:]], axis=-1) |
78 | 62 | ) |
79 | 63 |
|
80 | | - return field |
| 64 | + mid = jnp.moveaxis(mid, -1, axis) |
| 65 | + |
| 66 | + mid = jnp.where(mask, jnp.nan, mid) |
| 67 | + |
| 68 | + return mid |
81 | 69 |
|
82 | 70 |
|
83 | 71 | def derivative( |
@@ -116,41 +104,31 @@ def derivative( |
116 | 104 | field : Float[Array, "lat lon"] |
117 | 105 | Interpolated field |
118 | 106 | """ |
119 | | - def do_differentiate(field_b, field_f, mask_b, mask_f, pad_left): |
120 | | - field_b, field_f = handle_land_boundary(field_b, field_f, mask_b, mask_f, pad_left) |
121 | | - return field_f - field_b |
122 | | - |
123 | | - def axis0(pad_left): |
124 | | - field_b, field_f = field[:-1, :], field[1:, :] |
125 | | - mask_b, mask_f = mask[:-1, :], mask[1:, :] |
126 | | - midpoint_values = do_differentiate(field_b, field_f, mask_b, mask_f, pad_left) |
127 | | - |
128 | | - arr = lax.cond( |
129 | | - pad_left, |
130 | | - lambda: jnp.pad(midpoint_values, pad_width=((1, 0), (0, 0)), mode="edge"), |
131 | | - lambda: jnp.pad(midpoint_values, pad_width=((0, 1), (0, 0)), mode="edge") |
132 | | - ) |
133 | | - |
134 | | - return arr |
135 | | - |
136 | | - def axis1(pad_left): |
137 | | - field_b, field_f = field[:, :-1], field[:, 1:] |
138 | | - mask_b, mask_f = mask[:, :-1], mask[:, 1:] |
139 | | - midpoint_values = do_differentiate(field_b, field_f, mask_b, mask_f, pad_left) |
140 | | - |
141 | | - arr = lax.cond( |
142 | | - pad_left, |
143 | | - lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (1, 0)), mode="edge"), |
144 | | - lambda: jnp.pad(midpoint_values, pad_width=((0, 0), (0, 1)), mode="edge") |
145 | | - ) |
146 | | - |
147 | | - return arr |
148 | | - |
149 | | - field = lax.cond( |
150 | | - axis == 0, |
151 | | - lambda pad_left: axis0(pad_left), |
152 | | - lambda pad_left: axis1(pad_left), |
153 | | - padding == "left" |
| 107 | + f = jnp.moveaxis(field, axis, -1) |
| 108 | + |
| 109 | + mid = jnp.diff(f, axis=-1) |
| 110 | + |
| 111 | + # handle mask: extrapolate at land boundaries (up to 1 cell) |
| 112 | + mid = jnp.where( |
| 113 | + jnp.isnan(mid), |
| 114 | + jnp.pad(mid[:, 1:], pad_width=((0, 0), (0, 1)), mode="edge"), |
| 115 | + mid |
154 | 116 | ) |
| 117 | + mid = jnp.where( |
| 118 | + jnp.isnan(mid), |
| 119 | + jnp.pad(mid[:, :-1], pad_width=((0, 0), (1, 0)), mode="edge"), |
| 120 | + mid |
| 121 | + ) |
| 122 | + |
| 123 | + # extrapolate at the domain boundary |
| 124 | + mid = lax.cond( |
| 125 | + padding == "left", |
| 126 | + lambda: jnp.pad(mid, pad_width=((0, 0), (1, 0)), mode="edge"), |
| 127 | + lambda: jnp.pad(mid, pad_width=((0, 0), (0, 1)), mode="edge" ) |
| 128 | + ) |
| 129 | + |
| 130 | + mid = jnp.moveaxis(mid, -1, axis) |
| 131 | + |
| 132 | + mid = jnp.where(mask, jnp.nan, mid) |
155 | 133 |
|
156 | | - return field / dxy |
| 134 | + return mid / dxy |
0 commit comments