Skip to content

Commit e5b6c37

Browse files
committed
Add support for interpolation with 3D altitude
This commit adds support for interpolating datasets of the form (lon, lat, z), with (lon, lat) being 1D variables, and z being z(lon, lat) (a 3D variable). This is accomplished by splitting interpolation into two: first, we perform linear interpolation along the vertical direction, and the bilinear in the horizontal. More specifically, I interpolate the nodal points around the target point onto the same altitude. Then, I perform bilinear interpolation with the interpolated values. This code is hacky and not polished. It runs on GPUs only because we allow scalars. Boundary conditions are not implemented very well and could be imprecise.
1 parent 774e892 commit e5b6c37

File tree

3 files changed

+314
-10
lines changed

3 files changed

+314
-10
lines changed

ext/InterpolationsRegridderExt.jl

Lines changed: 176 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,189 @@ Regrid the given data as defined on the given dimensions to the `target_space` i
8080
This function is allocating.
8181
"""
8282
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
83+
# TODO: There is room for improvement in this function...
84+
8385
FT = ClimaCore.Spaces.undertype(regridder.target_space)
8486
dimensions_FT = map(d -> FT.(d), dimensions)
8587

86-
# Make a linear spline
87-
itp = Intp.extrapolate(
88-
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
89-
regridder.extrapolation_bc,
88+
coordinates = ClimaCore.Fields.coordinate_field(regridder.target_space)
89+
device = ClimaComms.device(regridder.target_space)
90+
91+
has_3d_z = length(size(last(dimensions))) == 3
92+
if eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint && has_3d_z
93+
# If we have 3D altitudes, we do linear in the vertical and bilinear
94+
# horizontal separately
95+
@warn "Ignoring boundary conditions, implementing Periodic, Flat, Flat"
96+
97+
return map(regridder.coordinates) do coord
98+
ClimaComms.allowscalar(
99+
interpolation_3d_z,
100+
device,
101+
data,
102+
dimensions_FT...,
103+
totuple(coord)...,
104+
)
105+
end
106+
else
107+
# Make a linear spline
108+
itp = Intp.extrapolate(
109+
Intp.interpolate(
110+
dimensions_FT,
111+
FT.(data),
112+
Intp.Gridded(Intp.Linear()),
113+
),
114+
regridder.extrapolation_bc,
115+
)
116+
117+
# Move it to GPU (if needed)
118+
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
119+
120+
return map(regridder.coordinates) do coord
121+
gpuitp(totuple(coord)...)
122+
end
123+
end
124+
end
125+
126+
"""
127+
interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
128+
129+
Perform bilinear + vertical interpolation on a 3D dataset.
130+
131+
This function first performs linear interpolation along the z-axis at the four
132+
corners of the cell containing the target (x, y) point. Then, it performs
133+
bilinear interpolation in the x-y plane using the z-interpolated values.
134+
135+
Periodic is implemented on the x direction, Flat on the other ones.
136+
137+
# Arguments
138+
- `data`: A 3D array of data values.
139+
- `xs`: A vector of x-coordinates corresponding to the first dimension of `data`.
140+
- `ys`: A vector of y-coordinates corresponding to the second dimension of `data`.
141+
- `zs`: A 3D array of z-coordinates. `zs[i, j, :]` provides the z-coordinates for the data point `data[i, j, :]`.
142+
- `target_x`: The x-coordinate of the target point.
143+
- `target_y`: The y-coordinate of the target point.
144+
- `target_z`: The z-coordinate of the target point.
145+
"""
146+
function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
147+
# Check boundaries
148+
# if target_x < xs[begin] || target_x > xs[end]
149+
# error(
150+
# "target_x is out of bounds: $(target_x) not in [$(xs[1]), $(xs[end])]",
151+
# )
152+
# end
153+
# if target_y < ys[begin] || target_y > ys[end]
154+
# error(
155+
# "target_y is out of bounds: $(target_y) not in [$(ys[1]), $(ys[end])]",
156+
# )
157+
# end
158+
159+
# Find nearest neighbors
160+
target_x = mod(target_x, maximum(xs) - minimum(xs))
161+
162+
x_index = searchsortedfirst(xs, target_x)
163+
y_index = searchsortedfirst(ys, target_y)
164+
165+
x0_index = x_index == 1 ? x_index : x_index - 1
166+
x1_index = x0_index + 1
167+
168+
y0_index = y_index == 1 ? y_index : y_index - 1
169+
# Flat
170+
y0_index = clamp(y0_index, 1, length(ys) - 1)
171+
y1_index = y0_index + 1
172+
if y0_index == 1 && y0_index == length(ys) - 1
173+
target_y = ys[y0_index]
174+
end
175+
176+
# Interpolate in z-direction
177+
z00 = zs[x0_index, y0_index, :]
178+
z01 = zs[x0_index, y1_index, :]
179+
z10 = zs[x1_index, y0_index, :]
180+
z11 = zs[x1_index, y1_index, :]
181+
182+
f00 = linear_interp_z(data[x0_index, y0_index, :], z00, target_z)
183+
f01 = linear_interp_z(data[x0_index, y1_index, :], z01, target_z)
184+
f10 = linear_interp_z(data[x1_index, y0_index, :], z10, target_z)
185+
f11 = linear_interp_z(data[x1_index, y1_index, :], z11, target_z)
186+
187+
# Bilinear interpolation in x-y plane
188+
val = bilinear_interp(
189+
f00,
190+
f01,
191+
f10,
192+
f11,
193+
xs[x0_index],
194+
xs[x1_index],
195+
ys[y0_index],
196+
ys[y1_index],
197+
target_x,
198+
target_y,
90199
)
200+
return val
201+
end
202+
203+
"""
204+
linear_interp_z(f, z, target_z)
91205
92-
# Move it to GPU (if needed)
93-
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
206+
Perform linear interpolation along the z-axis.
94207
95-
return map(regridder.coordinates) do coord
96-
gpuitp(totuple(coord)...)
208+
# Arguments
209+
- `f`: A vector of function values corresponding to the z-coordinates in `z`.
210+
- `z`: A vector of z-coordinates.
211+
- `target_z`: The z-coordinate at which to interpolate.
212+
213+
# Returns
214+
The linearly interpolated value at `target_z`.
215+
"""
216+
function linear_interp_z(f, z, target_z)
217+
# if target_z < z[begin] || target_z > z[end]
218+
# error(
219+
# "target_z is out of bounds: $(target_z) not in [$(z[1]), $(z[end])]",
220+
# )
221+
# end
222+
223+
index = searchsortedfirst(z, target_z)
224+
# Handle edge cases for index
225+
# Flat
226+
if index == 1
227+
z0 = z[index]
228+
z1 = z[index + 1]
229+
f0 = f[index]
230+
f1 = f[index + 1]
231+
else
232+
z0 = z[index - 1]
233+
z1 = z[index]
234+
f0 = f[index - 1]
235+
f1 = f[index]
97236
end
237+
238+
return f0 + (target_z - z0) / (z1 - z0) * (f1 - f0)
239+
end
240+
241+
"""
242+
bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
243+
244+
Perform bilinear interpolation on a 2D plane.
245+
246+
# Arguments
247+
- `f00`: Function value at (x0, y0).
248+
- `f01`: Function value at (x0, y1).
249+
- `f10`: Function value at (x1, y0).
250+
- `f11`: Function value at (x1, y1).
251+
- `x0`: x-coordinate of the first corner.
252+
- `x1`: x-coordinate of the second corner.
253+
- `y0`: y-coordinate of the first corner.
254+
- `y1`: y-coordinate of the second corner.
255+
- `target_x`: The x-coordinate of the target point.
256+
- `target_y`: The y-coordinate of the target point.
257+
"""
258+
function bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
259+
val = (
260+
(x1 - target_x) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f00 +
261+
(x1 - target_x) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f01 +
262+
(target_x - x0) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f10 +
263+
(target_x - x0) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f11
264+
)
265+
return val
98266
end
99267

100268
end

test/TestTools.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import ClimaComms
33
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
44

55
function make_spherical_space(FT; context = ClimaComms.context())
6-
radius = FT(128)
7-
zlim = (FT(0), FT(1))
6+
radius = FT(6300e3)
7+
zlim = (FT(0), FT(10000))
88
helem = 4
99
zelem = 10
1010
Nq = 4

test/interpolations_regridder.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
using Test
2+
import ClimaUtilities
3+
import Interpolations
4+
import ClimaUtilities: Regridders
5+
import ClimaComms
6+
import ClimaCore
7+
8+
linear_interp_z =
9+
Base.get_extension(
10+
ClimaUtilities,
11+
:ClimaUtilitiesClimaCoreInterpolationsExt,
12+
).InterpolationsRegridderExt.linear_interp_z
13+
bilinear_interp =
14+
Base.get_extension(
15+
ClimaUtilities,
16+
:ClimaUtilitiesClimaCoreInterpolationsExt,
17+
).InterpolationsRegridderExt.bilinear_interp
18+
interpolation_3d_z =
19+
Base.get_extension(
20+
ClimaUtilities,
21+
:ClimaUtilitiesClimaCoreInterpolationsExt,
22+
).InterpolationsRegridderExt.interpolation_3d_z
23+
24+
const context = ClimaComms.context()
25+
ClimaComms.init(context)
26+
27+
include("TestTools.jl")
28+
29+
@testset "Interpolation Tests" begin
30+
@testset "linear_interp_z" begin
31+
f = [1.0, 3.0, 5.0]
32+
z = [10.0, 20.0, 30.0]
33+
@test linear_interp_z(f, z, 15.0) 2.0
34+
@test linear_interp_z(f, z, 25.0) 4.0
35+
@test linear_interp_z(f, z, 10.0) 1.0
36+
@test linear_interp_z(f, z, 30.0) 5.0
37+
38+
# Out of bounds
39+
f = [1.0, 3.0]
40+
z = [10.0, 20.0]
41+
@test_throws ErrorException linear_interp_z(f, z, 5.0)
42+
@test_throws ErrorException linear_interp_z(f, z, 25.0)
43+
44+
45+
# One point
46+
f = [2.5]
47+
z = [15.0]
48+
@test_throws ErrorException linear_interp_z(f, z, 10.0)
49+
@test_throws ErrorException linear_interp_z(f, z, 20.0)
50+
51+
# Non uniform spacing
52+
f = [2.0, 4.0, 8.0]
53+
z = [1.0, 3.0, 7.0]
54+
@test linear_interp_z(f, z, 1.0) 2.0
55+
@test linear_interp_z(f, z, 3.0) 4.0
56+
@test linear_interp_z(f, z, 7.0) 8.0
57+
@test linear_interp_z(f, z, 2.0) 3.0
58+
@test linear_interp_z(f, z, 5.0) 6.0
59+
end
60+
61+
@testset "interpolation_3d_z" begin
62+
# Test cases for the main 3D interpolation function
63+
64+
# Create some sample data
65+
xs = [1.0, 2.0, 3.0]
66+
ys = [4.0, 5.0, 6.0]
67+
zs = zeros(3, 3, 8)
68+
for k in 1:8
69+
for j in 1:3
70+
for i in 1:3
71+
zs[i, j, k] = i + j + k
72+
end
73+
end
74+
end
75+
data = reshape(1:(3 * 3 * 8), 3, 3, 8)
76+
77+
# Exact point
78+
@test interpolation_3d_z(data, xs, ys, zs, xs[1], ys[1], zs[1, 1, 4])
79+
data[1, 1, 4]
80+
81+
# Interpolated point
82+
@test interpolation_3d_z(data, xs, ys, zs, 2.5, 5.5, 7.5) 20.5
83+
84+
# Out of bounds
85+
@test_throws ErrorException interpolation_3d_z(
86+
data,
87+
xs,
88+
ys,
89+
zs,
90+
0.5,
91+
4.5,
92+
3.5,
93+
)
94+
@test_throws ErrorException interpolation_3d_z(
95+
data,
96+
xs,
97+
ys,
98+
zs,
99+
2.5,
100+
4.5,
101+
4.5,
102+
)
103+
end
104+
105+
@testset "Regrid" begin
106+
107+
lon, lat, z =
108+
collect(-180.0:1:180), collect(-90.0:1:90), collect(0.0:1.0:100.0)
109+
size3D = (361, 181, 101)
110+
data_z3D = zeros(size3D)
111+
112+
for i in 1:length(lon)
113+
for j in 1:length(lat)
114+
data_z3D[i, j, :] .= z
115+
end
116+
end
117+
dimensions3D = (lon, lat, data_z3D)
118+
119+
FT = Float64
120+
spaces = make_spherical_space(FT; context)
121+
hv_center_space = spaces.hybrid
122+
extrapolation_bc = (
123+
Interpolations.Throw(),
124+
Interpolations.Throw(),
125+
Interpolations.Throw(),
126+
)
127+
reg_hv = Regridders.InterpolationsRegridder(
128+
hv_center_space;
129+
extrapolation_bc,
130+
)
131+
regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D)
132+
@test maximum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
133+
@test minimum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
134+
135+
end
136+
end

0 commit comments

Comments
 (0)