Skip to content

Commit 921d9d1

Browse files
committed
Add support for interpolation with 3D altitude
This commit adds support for interpolating datasets of the form (lon, lat, z), with (lon, lat) being 1D variables, and z being z(lon, lat) (a 3D variable). This is accomplished by splitting interpolation into two: first, we perform linear interpolation along the vertical direction, and the bilinear in the horizontal. More specifically, I interpolate the nodal points around the target point onto the same altitude. Then, I perform bilinear interpolation with the interpolated values. This code is not polished and does not run on GPU. I am not very happy about how it looks, but it does the job
1 parent 774e892 commit 921d9d1

File tree

2 files changed

+266
-8
lines changed

2 files changed

+266
-8
lines changed

ext/InterpolationsRegridderExt.jl

Lines changed: 159 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,21 +80,172 @@ Regrid the given data as defined on the given dimensions to the `target_space` i
8080
This function is allocating.
8181
"""
8282
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
83+
# TODO: There is room for improvement in this function...
84+
8385
FT = ClimaCore.Spaces.undertype(regridder.target_space)
8486
dimensions_FT = map(d -> FT.(d), dimensions)
8587

86-
# Make a linear spline
87-
itp = Intp.extrapolate(
88-
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
89-
regridder.extrapolation_bc,
88+
coordinates = ClimaCore.Fields.coordinate_field(regridder.target_space)
89+
90+
has_3d_z = length(size(last(dimensions))) == 3
91+
if eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint && has_3d_z
92+
# If we have 3D altitudes, we do linear in the vertical and bilinear
93+
# horizontal separately
94+
95+
# TODO: Support other boundary conditions
96+
any(x -> x != Intp.Throw(), regridder.extrapolation_bc) && error(
97+
"Only Throw() boundary condition currently implemented $(regridder.extrapolation_bc)",
98+
)
99+
return map(regridder.coordinates) do coord
100+
interpolation_3d_z(data, dimensions_FT..., totuple(coord)...)
101+
end
102+
else
103+
# Make a linear spline
104+
itp = Intp.extrapolate(
105+
Intp.interpolate(
106+
dimensions_FT,
107+
FT.(data),
108+
Intp.Gridded(Intp.Linear()),
109+
),
110+
regridder.extrapolation_bc,
111+
)
112+
113+
# Move it to GPU (if needed)
114+
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
115+
116+
return map(regridder.coordinates) do coord
117+
gpuitp(totuple(coord)...)
118+
end
119+
end
120+
end
121+
122+
"""
123+
interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
124+
125+
Perform bilinear + vertical interpolation on a 3D dataset.
126+
127+
This function first performs linear interpolation along the z-axis at the four
128+
corners of the cell containing the target (x, y) point. Then, it performs
129+
bilinear interpolation in the x-y plane using the z-interpolated values.
130+
131+
# Arguments
132+
- `data`: A 3D array of data values.
133+
- `xs`: A vector of x-coordinates corresponding to the first dimension of `data`.
134+
- `ys`: A vector of y-coordinates corresponding to the second dimension of `data`.
135+
- `zs`: A 3D array of z-coordinates. `zs[i, j, :]` provides the z-coordinates for the data point `data[i, j, :]`.
136+
- `target_x`: The x-coordinate of the target point.
137+
- `target_y`: The y-coordinate of the target point.
138+
- `target_z`: The z-coordinate of the target point.
139+
"""
140+
function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
141+
# Check boundaries
142+
if target_x < xs[begin] || target_x > xs[end]
143+
error(
144+
"target_x is out of bounds: $(target_x) not in [$(xs[1]), $(xs[end])]",
145+
)
146+
end
147+
if target_y < ys[begin] || target_y > ys[end]
148+
error(
149+
"target_y is out of bounds: $(target_y) not in [$(ys[1]), $(ys[end])]",
150+
)
151+
end
152+
153+
# Find nearest neighbors
154+
x_index = searchsortedfirst(xs, target_x)
155+
y_index = searchsortedfirst(ys, target_y)
156+
157+
x0_index = x_index == 1 ? x_index : x_index - 1
158+
x1_index = x0_index + 1
159+
y0_index = y_index == 1 ? y_index : y_index - 1
160+
y1_index = y0_index + 1
161+
162+
# Interpolate in z-direction
163+
z00 = zs[x0_index, y0_index, :]
164+
z01 = zs[x0_index, y1_index, :]
165+
z10 = zs[x1_index, y0_index, :]
166+
z11 = zs[x1_index, y1_index, :]
167+
168+
f00 = linear_interp_z(data[x0_index, y0_index, :], z00, target_z)
169+
f01 = linear_interp_z(data[x0_index, y1_index, :], z01, target_z)
170+
f10 = linear_interp_z(data[x1_index, y0_index, :], z10, target_z)
171+
f11 = linear_interp_z(data[x1_index, y1_index, :], z11, target_z)
172+
173+
# Bilinear interpolation in x-y plane
174+
return bilinear_interp(
175+
f00,
176+
f01,
177+
f10,
178+
f11,
179+
xs[x0_index],
180+
xs[x1_index],
181+
ys[y0_index],
182+
ys[y1_index],
183+
target_x,
184+
target_y,
90185
)
186+
end
187+
188+
"""
189+
linear_interp_z(f, z, target_z)
190+
191+
Perform linear interpolation along the z-axis.
91192
92-
# Move it to GPU (if needed)
93-
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
193+
# Arguments
194+
- `f`: A vector of function values corresponding to the z-coordinates in `z`.
195+
- `z`: A vector of z-coordinates.
196+
- `target_z`: The z-coordinate at which to interpolate.
197+
198+
# Returns
199+
The linearly interpolated value at `target_z`.
200+
"""
201+
function linear_interp_z(f, z, target_z)
202+
if target_z < z[begin] || target_z > z[end]
203+
error(
204+
"target_z is out of bounds: $(target_z) not in [$(z[1]), $(z[end])]",
205+
)
206+
end
94207

95-
return map(regridder.coordinates) do coord
96-
gpuitp(totuple(coord)...)
208+
index = searchsortedfirst(z, target_z)
209+
# Handle edge cases for index
210+
if index == 1
211+
z0 = z[index]
212+
z1 = z[index + 1]
213+
f0 = f[index]
214+
f1 = f[index + 1]
215+
else
216+
z0 = z[index - 1]
217+
z1 = z[index]
218+
f0 = f[index - 1]
219+
f1 = f[index]
97220
end
221+
222+
return f0 + (target_z - z0) / (z1 - z0) * (f1 - f0)
223+
end
224+
225+
"""
226+
bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
227+
228+
Perform bilinear interpolation on a 2D plane.
229+
230+
# Arguments
231+
- `f00`: Function value at (x0, y0).
232+
- `f01`: Function value at (x0, y1).
233+
- `f10`: Function value at (x1, y0).
234+
- `f11`: Function value at (x1, y1).
235+
- `x0`: x-coordinate of the first corner.
236+
- `x1`: x-coordinate of the second corner.
237+
- `y0`: y-coordinate of the first corner.
238+
- `y1`: y-coordinate of the second corner.
239+
- `target_x`: The x-coordinate of the target point.
240+
- `target_y`: The y-coordinate of the target point.
241+
"""
242+
function bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
243+
return (
244+
(x1 - target_x) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f00 +
245+
(x1 - target_x) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f01 +
246+
(target_x - x0) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f10 +
247+
(target_x - x0) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f11
248+
)
98249
end
99250

100251
end

test/interpolations_regridder.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
using Test
2+
import ClimaUtilities
3+
import Interpolations
4+
import ClimaUtilities: Regridders
5+
import ClimaComms
6+
import ClimaCore
7+
8+
linear_interp_z = Base.get_extension(ClimaUtilities, :ClimaUtilitiesClimaCoreInterpolationsExt).InterpolationsRegridderExt.linear_interp_z
9+
bilinear_interp = Base.get_extension(ClimaUtilities, :ClimaUtilitiesClimaCoreInterpolationsExt).InterpolationsRegridderExt.bilinear_interp
10+
interpolation_3d_z = Base.get_extension(ClimaUtilities, :ClimaUtilitiesClimaCoreInterpolationsExt).InterpolationsRegridderExt.interpolation_3d_z
11+
12+
const context = ClimaComms.context()
13+
ClimaComms.init(context)
14+
15+
include("TestTools.jl")
16+
17+
@testset "Interpolation Tests" begin
18+
@testset "linear_interp_z" begin
19+
f = [1.0, 3.0, 5.0]
20+
z = [10.0, 20.0, 30.0]
21+
@test linear_interp_z(f, z, 15.0) 2.0
22+
@test linear_interp_z(f, z, 25.0) 4.0
23+
@test linear_interp_z(f, z, 10.0) 1.0
24+
@test linear_interp_z(f, z, 30.0) 5.0
25+
26+
# Out of bounds
27+
f = [1.0, 3.0]
28+
z = [10.0, 20.0]
29+
@test_throws ErrorException linear_interp_z(f, z, 5.0)
30+
@test_throws ErrorException linear_interp_z(f, z, 25.0)
31+
32+
33+
# One point
34+
f = [2.5]
35+
z = [15.0]
36+
@test_throws ErrorException linear_interp_z(f, z, 10.0)
37+
@test_throws ErrorException linear_interp_z(f, z, 20.0)
38+
39+
# Non uniform spacing
40+
f = [2.0, 4.0, 8.0]
41+
z = [1.0, 3.0, 7.0]
42+
@test linear_interp_z(f,z, 1.0) 2.0
43+
@test linear_interp_z(f,z, 3.0) 4.0
44+
@test linear_interp_z(f,z, 7.0) 8.0
45+
@test linear_interp_z(f,z, 2.0) 3.0
46+
@test linear_interp_z(f,z, 5.0) 6.0
47+
end
48+
49+
@testset "interpolation_3d_z" begin
50+
# Test cases for the main 3D interpolation function
51+
52+
# Create some sample data
53+
xs = [1.0, 2.0, 3.0]
54+
ys = [4.0, 5.0, 6.0]
55+
zs = zeros(3, 3, 8)
56+
for k in 1:8
57+
for j in 1:3
58+
for i in 1:3
59+
zs[i, j, k] = i + j + k
60+
end
61+
end
62+
end
63+
data = reshape(1:(3*3*8), 3, 3, 8)
64+
65+
# Exact point
66+
@test interpolation_3d_z(data, xs, ys, zs, xs[1], ys[1], zs[1, 1, 4]) data[1, 1, 4]
67+
68+
# Interpolated point
69+
@test interpolation_3d_z(data, xs, ys, zs, 2.5, 5.5, 7.5) 20.5
70+
71+
# Out of bounds
72+
@test_throws ErrorException interpolation_3d_z(data, xs, ys, zs, 0.5, 4.5, 3.5)
73+
@test_throws ErrorException interpolation_3d_z(data, xs, ys, zs, 2.5, 4.5, 4.5)
74+
end
75+
76+
@testset "Regrid" begin
77+
78+
lon, lat, z =
79+
collect(-180.0:1:180), collect(-90.0:1:90), collect(0.0:1.0:100.0)
80+
size3D = (361, 181, 101)
81+
data_z3D = zeros(size3D)
82+
83+
for i in 1:length(lon)
84+
for j in 1:length(lat)
85+
data_z3D[i, j, :] .= z
86+
end
87+
end
88+
dimensions3D = (lon, lat, data_z3D)
89+
90+
FT = Float64
91+
spaces = make_spherical_space(FT; context)
92+
hv_center_space = spaces.hybrid
93+
extrapolation_bc = (
94+
Interpolations.Throw(),
95+
Interpolations.Throw(),
96+
Interpolations.Throw(),
97+
)
98+
reg_hv = Regridders.InterpolationsRegridder(
99+
hv_center_space;
100+
extrapolation_bc,
101+
)
102+
regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D)
103+
@test maximum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
104+
@test minimum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
105+
106+
end
107+
end

0 commit comments

Comments
 (0)