Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/src/inputs.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ Read more about this feature in the page about [`DataHandler`](@ref datahandling
default, the `Throw` condition is used, meaning that interpolating onto a point
that is outside the range of definition of the data is not allowed. Other
boundary conditions are allowed. With the `Flat` boundary condition, when
interpolating outside of the range of definition, return the value of the
of closest boundary is used instead.
interpolating outside of the range of definition, the value of the closest
boundary is used instead.

Another boundary condition that is often useful is `PeriodicCalendar`, which
repeats data over and over.
Expand Down
204 changes: 196 additions & 8 deletions ext/InterpolationsRegridderExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,209 @@ Regrid the given data as defined on the given dimensions to the `target_space` i
This function is allocating.
"""
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
# TODO: There is room for improvement in this function...

FT = ClimaCore.Spaces.undertype(regridder.target_space)
dimensions_FT = map(d -> FT.(d), dimensions)

# Make a linear spline
itp = Intp.extrapolate(
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
regridder.extrapolation_bc,
coordinates = ClimaCore.Fields.coordinate_field(regridder.target_space)
device = ClimaComms.device(regridder.target_space)

has_3d_z = length(size(last(dimensions))) == 3
if eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint && has_3d_z
# If we have 3D altitudes, we do linear in the vertical and bilinear
# horizontal separately
@warn "Ignoring boundary conditions, implementing Periodic, Flat, Flat"

adapted_data = Adapt.adapt(ClimaComms.array_type(regridder.target_space), data)
xs, ys, zs = dimensions_FT
adapted_xs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), xs)
adapted_ys = Adapt.adapt(ClimaComms.array_type(regridder.target_space), ys)
adapted_zs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), zs)

return ClimaComms.allowscalar(ClimaComms.device(regridder.target_space)) do
map(regridder.coordinates) do coord
interpolation_3d_z(
adapted_data,
adapted_xs, adapted_ys, adapted_zs,
totuple(coord)...,
)
end
end
else
# Make a linear spline
itp = Intp.extrapolate(
Intp.interpolate(
dimensions_FT,
FT.(data),
Intp.Gridded(Intp.Linear()),
),
regridder.extrapolation_bc,
)

# Move it to GPU (if needed)
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)

return map(regridder.coordinates) do coord
gpuitp(totuple(coord)...)
end
end
end

"""
interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
Perform bilinear + vertical interpolation on a 3D dataset.
This function first performs linear interpolation along the z-axis at the four
corners of the cell containing the target (x, y) point. Then, it performs
bilinear interpolation in the x-y plane using the z-interpolated values.
Periodic is implemented on the x direction, Flat on the other ones.
# Arguments
- `data`: A 3D array of data values.
- `xs`: A vector of x-coordinates corresponding to the first dimension of `data`.
- `ys`: A vector of y-coordinates corresponding to the second dimension of `data`.
- `zs`: A 3D array of z-coordinates. `zs[i, j, :]` provides the z-coordinates for the data point `data[i, j, :]`.
- `target_x`: The x-coordinate of the target point.
- `target_y`: The y-coordinate of the target point.
- `target_z`: The z-coordinate of the target point.
"""
function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
# Check boundaries
# if target_x < xs[begin] || target_x > xs[end]
# error(
# "target_x is out of bounds: $(target_x) not in [$(xs[1]), $(xs[end])]",
# )
# end
# if target_y < ys[begin] || target_y > ys[end]
# error(
# "target_y is out of bounds: $(target_y) not in [$(ys[1]), $(ys[end])]",
# )
# end

# Find nearest neighbors
x_period = xs[end] - xs[begin]
target_x = mod(target_x, x_period)

x_index = searchsortedfirst(xs, target_x)
y_index = searchsortedfirst(ys, target_y)

x0_index = x_index == 1 ? x_index : x_index - 1
x1_index = x0_index + 1

y0_index = y_index == 1 ? y_index : y_index - 1
# Flat
y0_index = clamp(y0_index, 1, length(ys) - 1)
y1_index = y0_index + 1
if y0_index == 1
target_y = ys[y0_index]
end
if y1_index == length(ys)
target_y = ys[y1_index]
end


# Interpolate in z-direction

z00 = @view zs[x0_index, y0_index, :]
z01 = @view zs[x0_index, y1_index, :]
z10 = @view zs[x1_index, y0_index, :]
z11 = @view zs[x1_index, y1_index, :]

f00 = linear_interp_z(view(data,x0_index, y0_index, :), z00, target_z)
f01 = linear_interp_z(view(data,x0_index, y1_index, :), z01, target_z)
f10 = linear_interp_z(view(data,x1_index, y0_index, :), z10, target_z)
f11 = linear_interp_z(view(data,x1_index, y1_index, :), z11, target_z)

# Bilinear interpolation in x-y plane
val = bilinear_interp(
f00,
f01,
f10,
f11,
xs[x0_index],
xs[x1_index],
ys[y0_index],
ys[y1_index],
target_x,
target_y,
)

# Move it to GPU (if needed)
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
return val
end

"""
linear_interp_z(f, z, target_z)
Perform linear interpolation along the z-axis.
# Arguments
- `f`: A vector of function values corresponding to the z-coordinates in `z`.
- `z`: A vector of z-coordinates.
- `target_z`: The z-coordinate at which to interpolate.
# Returns
The linearly interpolated value at `target_z`.
"""
function linear_interp_z(f, z, target_z)
# if target_z < z[begin] || target_z > z[end]
# error(
# "target_z is out of bounds: $(target_z) not in [$(z[1]), $(z[end])]",
# )
# end

index = searchsortedfirst(z, target_z)
# Handle edge cases for index
# Flat
if index == 1
z0 = z[index]
z1 = z[index + 1]
f0 = f[index]
f1 = f[index + 1]
else
z0 = z[index - 1]
z1 = z[index]
f0 = f[index - 1]
f1 = f[index]
end

return map(regridder.coordinates) do coord
gpuitp(totuple(coord)...)
if index == 1
target_z = z[index]
end
if index == length(z) - 1
target_z = z[index + 1]
end
val = f0 + (target_z - z0) / (z1 - z0) * (f1 - f0)
return val
end

"""
bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
Perform bilinear interpolation on a 2D plane.
# Arguments
- `f00`: Function value at (x0, y0).
- `f01`: Function value at (x0, y1).
- `f10`: Function value at (x1, y0).
- `f11`: Function value at (x1, y1).
- `x0`: x-coordinate of the first corner.
- `x1`: x-coordinate of the second corner.
- `y0`: y-coordinate of the first corner.
- `y1`: y-coordinate of the second corner.
- `target_x`: The x-coordinate of the target point.
- `target_y`: The y-coordinate of the target point.
"""
function bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y)
val = (
(x1 - target_x) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f00 +
(x1 - target_x) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f01 +
(target_x - x0) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f10 +
(target_x - x0) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f11
)
return val
end

end
136 changes: 136 additions & 0 deletions test/interpolations_regridder.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
using Test
import ClimaUtilities
import Interpolations
import ClimaUtilities: Regridders
import ClimaComms
import ClimaCore

linear_interp_z =
Base.get_extension(
ClimaUtilities,
:ClimaUtilitiesClimaCoreInterpolationsExt,
).InterpolationsRegridderExt.linear_interp_z
bilinear_interp =
Base.get_extension(
ClimaUtilities,
:ClimaUtilitiesClimaCoreInterpolationsExt,
).InterpolationsRegridderExt.bilinear_interp
interpolation_3d_z =
Base.get_extension(
ClimaUtilities,
:ClimaUtilitiesClimaCoreInterpolationsExt,
).InterpolationsRegridderExt.interpolation_3d_z

const context = ClimaComms.context()
ClimaComms.init(context)

include("TestTools.jl")

@testset "Interpolation Tests" begin
@testset "linear_interp_z" begin
f = [1.0, 3.0, 5.0]
z = [10.0, 20.0, 30.0]
@test linear_interp_z(f, z, 15.0) 2.0
@test linear_interp_z(f, z, 25.0) 4.0
@test linear_interp_z(f, z, 10.0) 1.0
@test linear_interp_z(f, z, 30.0) 5.0

# Out of bounds
f = [1.0, 3.0]
z = [10.0, 20.0]
@test_throws ErrorException linear_interp_z(f, z, 5.0)
@test_throws ErrorException linear_interp_z(f, z, 25.0)


# One point
f = [2.5]
z = [15.0]
@test_throws ErrorException linear_interp_z(f, z, 10.0)
@test_throws ErrorException linear_interp_z(f, z, 20.0)

# Non uniform spacing
f = [2.0, 4.0, 8.0]
z = [1.0, 3.0, 7.0]
@test linear_interp_z(f, z, 1.0) 2.0
@test linear_interp_z(f, z, 3.0) 4.0
@test linear_interp_z(f, z, 7.0) 8.0
@test linear_interp_z(f, z, 2.0) 3.0
@test linear_interp_z(f, z, 5.0) 6.0
end

@testset "interpolation_3d_z" begin
# Test cases for the main 3D interpolation function

# Create some sample data
xs = [1.0, 2.0, 3.0]
ys = [4.0, 5.0, 6.0]
zs = zeros(3, 3, 8)
for k in 1:8
for j in 1:3
for i in 1:3
zs[i, j, k] = i + j + k
end
end
end
data = reshape(1:(3 * 3 * 8), 3, 3, 8)

# Exact point
@test interpolation_3d_z(data, xs, ys, zs, xs[1], ys[1], zs[1, 1, 4])
data[1, 1, 4]

# Interpolated point
@test interpolation_3d_z(data, xs, ys, zs, 2.5, 5.5, 7.5) 20.5

# Out of bounds
@test_throws ErrorException interpolation_3d_z(
data,
xs,
ys,
zs,
0.5,
4.5,
3.5,
)
@test_throws ErrorException interpolation_3d_z(
data,
xs,
ys,
zs,
2.5,
4.5,
4.5,
)
end

@testset "Regrid" begin

lon, lat, z =
collect(-180.0:1:180), collect(-90.0:1:90), collect(0.0:1.0:100.0)
size3D = (361, 181, 101)
data_z3D = zeros(size3D)

for i in 1:length(lon)
for j in 1:length(lat)
data_z3D[i, j, :] .= z
end
end
dimensions3D = (lon, lat, data_z3D)

FT = Float64
spaces = make_spherical_space(FT; context)
hv_center_space = spaces.hybrid
extrapolation_bc = (
Interpolations.Throw(),
Interpolations.Throw(),
Interpolations.Throw(),
)
reg_hv = Regridders.InterpolationsRegridder(
hv_center_space;
extrapolation_bc,
)
regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D)
@test maximum(ClimaCore.Fields.level(regridded_z, 2)) 0.15
@test minimum(ClimaCore.Fields.level(regridded_z, 2)) 0.15

end
end