Skip to content

Commit 807a28d

Browse files
committed
Add support for interpolation with 3D altitude
This commit adds support for interpolating datasets of the form (lon, lat, z), with (lon, lat) being 1D variables, and z being z(lon, lat) (a 3D variable). This is accomplished by splitting interpolation into two: first, we perform linear interpolation along the vertical direction, and the bilinear in the horizontal. More specifically, I interpolate the nodal points around the target point onto the same altitude. Then, I perform bilinear interpolation with the interpolated values. This code is hacky and not polished. It runs on GPUs only because we allow scalars. Boundary conditions are not implemented very well and could be imprecise.
1 parent 774e892 commit 807a28d

File tree

3 files changed

+334
-10
lines changed

3 files changed

+334
-10
lines changed

docs/src/inputs.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ Read more about this feature in the page about [`DataHandler`](@ref datahandling
117117
default, the `Throw` condition is used, meaning that interpolating onto a point
118118
that is outside the range of definition of the data is not allowed. Other
119119
boundary conditions are allowed. With the `Flat` boundary condition, when
120-
interpolating outside of the range of definition, return the value of the
121-
of closest boundary is used instead.
120+
interpolating outside of the range of definition, the value of the closest
121+
boundary is used instead.
122122

123123
Another boundary condition that is often useful is `PeriodicCalendar`, which
124124
repeats data over and over.

ext/InterpolationsRegridderExt.jl

Lines changed: 196 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,209 @@ Regrid the given data as defined on the given dimensions to the `target_space` i
8080
This function is allocating.
8181
"""
8282
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
83+
# TODO: There is room for improvement in this function...
84+
8385
FT = ClimaCore.Spaces.undertype(regridder.target_space)
8486
dimensions_FT = map(d -> FT.(d), dimensions)
8587

86-
# Make a linear spline
87-
itp = Intp.extrapolate(
88-
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
89-
regridder.extrapolation_bc,
88+
coordinates = ClimaCore.Fields.coordinate_field(regridder.target_space)
89+
device = ClimaComms.device(regridder.target_space)
90+
91+
has_3d_z = length(size(last(dimensions))) == 3
92+
if eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint && has_3d_z
93+
# If we have 3D altitudes, we do linear in the vertical and bilinear
94+
# horizontal separately
95+
@warn "Ignoring boundary conditions, implementing Periodic, Flat, Flat"
96+
97+
adapted_data = Adapt.adapt(ClimaComms.array_type(regridder.target_space), data)
98+
xs, ys, zs = dimensions_FT
99+
adapted_xs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), xs)
100+
adapted_ys = Adapt.adapt(ClimaComms.array_type(regridder.target_space), ys)
101+
adapted_zs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), zs)
102+
103+
return ClimaComms.allowscalar(ClimaComms.device(regridder.target_space)) do
104+
map(regridder.coordinates) do coord
105+
interpolation_3d_z(
106+
adapted_data,
107+
adapted_xs, adapted_ys, adapted_zs,
108+
totuple(coord)...,
109+
)
110+
end
111+
end
112+
else
113+
# Make a linear spline
114+
itp = Intp.extrapolate(
115+
Intp.interpolate(
116+
dimensions_FT,
117+
FT.(data),
118+
Intp.Gridded(Intp.Linear()),
119+
),
120+
regridder.extrapolation_bc,
121+
)
122+
123+
# Move it to GPU (if needed)
124+
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
125+
126+
return map(regridder.coordinates) do coord
127+
gpuitp(totuple(coord)...)
128+
end
129+
end
130+
end
131+
132+
"""
133+
interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
134+
135+
Perform bilinear + vertical interpolation on a 3D dataset.
136+
137+
This function first performs linear interpolation along the z-axis at the four
138+
corners of the cell containing the target (x, y) point. Then, it performs
139+
bilinear interpolation in the x-y plane using the z-interpolated values.
140+
141+
Periodic is implemented on the x direction, Flat on the other ones.
142+
143+
# Arguments
144+
- `data`: A 3D array of data values.
145+
- `xs`: A vector of x-coordinates corresponding to the first dimension of `data`.
146+
- `ys`: A vector of y-coordinates corresponding to the second dimension of `data`.
147+
- `zs`: A 3D array of z-coordinates. `zs[i, j, :]` provides the z-coordinates for the data point `data[i, j, :]`.
148+
- `target_x`: The x-coordinate of the target point.
149+
- `target_y`: The y-coordinate of the target point.
150+
- `target_z`: The z-coordinate of the target point.
151+
"""
152+
function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
153+
# Check boundaries
154+
# if target_x < xs[begin] || target_x > xs[end]
155+
# error(
156+
# "target_x is out of bounds: $(target_x) not in [$(xs[1]), $(xs[end])]",
157+
# )
158+
# end
159+
# if target_y < ys[begin] || target_y > ys[end]
160+
# error(
161+
# "target_y is out of bounds: $(target_y) not in [$(ys[1]), $(ys[end])]",
162+
# )
163+
# end
164+
165+
# Find nearest neighbors
166+
x_period = xs[end] - xs[begin]
167+
target_x = mod(target_x, x_period)
168+
169+
x_index = searchsortedfirst(xs, target_x)
170+
y_index = searchsortedfirst(ys, target_y)
171+
172+
x0_index = x_index == 1 ? x_index : x_index - 1
173+
x1_index = x0_index + 1
174+
175+
y0_index = y_index == 1 ? y_index : y_index - 1
176+
# Flat
177+
y0_index = clamp(y0_index, 1, length(ys) - 1)
178+
y1_index = y0_index + 1
179+
if y0_index == 1
180+
target_y = ys[y0_index]
181+
end
182+
if y1_index == length(ys)
183+
target_y = ys[y1_index]
184+
end
185+
186+
187+
# Interpolate in z-direction
188+
189+
z00 = @view zs[x0_index, y0_index, :]
190+
z01 = @view zs[x0_index, y1_index, :]
191+
z10 = @view zs[x1_index, y0_index, :]
192+
z11 = @view zs[x1_index, y1_index, :]
193+
194+
f00 = linear_interp_z(view(data,x0_index, y0_index, :), z00, target_z)
195+
f01 = linear_interp_z(view(data,x0_index, y1_index, :), z01, target_z)
196+
f10 = linear_interp_z(view(data,x1_index, y0_index, :), z10, target_z)
197+
f11 = linear_interp_z(view(data,x1_index, y1_index, :), z11, target_z)
198+
199+
# Bilinear interpolation in x-y plane
200+
val = bilinear_interp(
201+
f00,
202+
f01,
203+
f10,
204+
f11,
205+
xs[x0_index],
206+
xs[x1_index],
207+
ys[y0_index],
208+
ys[y1_index],
209+
target_x,
210+
target_y,
90211
)
91212

92-
# Move it to GPU (if needed)
93-
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
213+
return val
214+
end
215+
216+
"""
217+
linear_interp_z(f, z, target_z)
218+
219+
Perform linear interpolation along the z-axis.
220+
221+
# Arguments
222+
- `f`: A vector of function values corresponding to the z-coordinates in `z`.
223+
- `z`: A vector of z-coordinates.
224+
- `target_z`: The z-coordinate at which to interpolate.
225+
226+
# Returns
227+
The linearly interpolated value at `target_z`.
228+
"""
229+
function linear_interp_z(f, z, target_z)
230+
# if target_z < z[begin] || target_z > z[end]
231+
# error(
232+
# "target_z is out of bounds: $(target_z) not in [$(z[1]), $(z[end])]",
233+
# )
234+
# end
235+
236+
index = searchsortedfirst(z, target_z)
237+
# Handle edge cases for index
238+
# Flat
239+
if index == 1
240+
z0 = z[index]
241+
z1 = z[index + 1]
242+
f0 = f[index]
243+
f1 = f[index + 1]
244+
else
245+
z0 = z[index - 1]
246+
z1 = z[index]
247+
f0 = f[index - 1]
248+
f1 = f[index]
249+
end
94250

95-
return map(regridder.coordinates) do coord
96-
gpuitp(totuple(coord)...)
251+
if index == 1
252+
target_z = z[index]
97253
end
254+
if index == length(z) - 1
255+
target_z = z[index + 1]
256+
end
257+
val = f0 + (target_z - z0) / (z1 - z0) * (f1 - f0)
258+
return val
259+
end
260+
261+
"""
262+
bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
263+
264+
Perform bilinear interpolation on a 2D plane.
265+
266+
# Arguments
267+
- `f00`: Function value at (x0, y0).
268+
- `f01`: Function value at (x0, y1).
269+
- `f10`: Function value at (x1, y0).
270+
- `f11`: Function value at (x1, y1).
271+
- `x0`: x-coordinate of the first corner.
272+
- `x1`: x-coordinate of the second corner.
273+
- `y0`: y-coordinate of the first corner.
274+
- `y1`: y-coordinate of the second corner.
275+
- `target_x`: The x-coordinate of the target point.
276+
- `target_y`: The y-coordinate of the target point.
277+
"""
278+
function bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
279+
val = (
280+
(x1 - target_x) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f00 +
281+
(x1 - target_x) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f01 +
282+
(target_x - x0) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f10 +
283+
(target_x - x0) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f11
284+
)
285+
return val
98286
end
99287

100288
end

test/interpolations_regridder.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
using Test
2+
import ClimaUtilities
3+
import Interpolations
4+
import ClimaUtilities: Regridders
5+
import ClimaComms
6+
import ClimaCore
7+
8+
linear_interp_z =
9+
Base.get_extension(
10+
ClimaUtilities,
11+
:ClimaUtilitiesClimaCoreInterpolationsExt,
12+
).InterpolationsRegridderExt.linear_interp_z
13+
bilinear_interp =
14+
Base.get_extension(
15+
ClimaUtilities,
16+
:ClimaUtilitiesClimaCoreInterpolationsExt,
17+
).InterpolationsRegridderExt.bilinear_interp
18+
interpolation_3d_z =
19+
Base.get_extension(
20+
ClimaUtilities,
21+
:ClimaUtilitiesClimaCoreInterpolationsExt,
22+
).InterpolationsRegridderExt.interpolation_3d_z
23+
24+
const context = ClimaComms.context()
25+
ClimaComms.init(context)
26+
27+
include("TestTools.jl")
28+
29+
@testset "Interpolation Tests" begin
30+
@testset "linear_interp_z" begin
31+
f = [1.0, 3.0, 5.0]
32+
z = [10.0, 20.0, 30.0]
33+
@test linear_interp_z(f, z, 15.0) 2.0
34+
@test linear_interp_z(f, z, 25.0) 4.0
35+
@test linear_interp_z(f, z, 10.0) 1.0
36+
@test linear_interp_z(f, z, 30.0) 5.0
37+
38+
# Out of bounds
39+
f = [1.0, 3.0]
40+
z = [10.0, 20.0]
41+
@test_throws ErrorException linear_interp_z(f, z, 5.0)
42+
@test_throws ErrorException linear_interp_z(f, z, 25.0)
43+
44+
45+
# One point
46+
f = [2.5]
47+
z = [15.0]
48+
@test_throws ErrorException linear_interp_z(f, z, 10.0)
49+
@test_throws ErrorException linear_interp_z(f, z, 20.0)
50+
51+
# Non uniform spacing
52+
f = [2.0, 4.0, 8.0]
53+
z = [1.0, 3.0, 7.0]
54+
@test linear_interp_z(f, z, 1.0) 2.0
55+
@test linear_interp_z(f, z, 3.0) 4.0
56+
@test linear_interp_z(f, z, 7.0) 8.0
57+
@test linear_interp_z(f, z, 2.0) 3.0
58+
@test linear_interp_z(f, z, 5.0) 6.0
59+
end
60+
61+
@testset "interpolation_3d_z" begin
62+
# Test cases for the main 3D interpolation function
63+
64+
# Create some sample data
65+
xs = [1.0, 2.0, 3.0]
66+
ys = [4.0, 5.0, 6.0]
67+
zs = zeros(3, 3, 8)
68+
for k in 1:8
69+
for j in 1:3
70+
for i in 1:3
71+
zs[i, j, k] = i + j + k
72+
end
73+
end
74+
end
75+
data = reshape(1:(3 * 3 * 8), 3, 3, 8)
76+
77+
# Exact point
78+
@test interpolation_3d_z(data, xs, ys, zs, xs[1], ys[1], zs[1, 1, 4])
79+
data[1, 1, 4]
80+
81+
# Interpolated point
82+
@test interpolation_3d_z(data, xs, ys, zs, 2.5, 5.5, 7.5) 20.5
83+
84+
# Out of bounds
85+
@test_throws ErrorException interpolation_3d_z(
86+
data,
87+
xs,
88+
ys,
89+
zs,
90+
0.5,
91+
4.5,
92+
3.5,
93+
)
94+
@test_throws ErrorException interpolation_3d_z(
95+
data,
96+
xs,
97+
ys,
98+
zs,
99+
2.5,
100+
4.5,
101+
4.5,
102+
)
103+
end
104+
105+
@testset "Regrid" begin
106+
107+
lon, lat, z =
108+
collect(-180.0:1:180), collect(-90.0:1:90), collect(0.0:1.0:100.0)
109+
size3D = (361, 181, 101)
110+
data_z3D = zeros(size3D)
111+
112+
for i in 1:length(lon)
113+
for j in 1:length(lat)
114+
data_z3D[i, j, :] .= z
115+
end
116+
end
117+
dimensions3D = (lon, lat, data_z3D)
118+
119+
FT = Float64
120+
spaces = make_spherical_space(FT; context)
121+
hv_center_space = spaces.hybrid
122+
extrapolation_bc = (
123+
Interpolations.Throw(),
124+
Interpolations.Throw(),
125+
Interpolations.Throw(),
126+
)
127+
reg_hv = Regridders.InterpolationsRegridder(
128+
hv_center_space;
129+
extrapolation_bc,
130+
)
131+
regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D)
132+
@test maximum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
133+
@test minimum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
134+
135+
end
136+
end

0 commit comments

Comments
 (0)