diff --git a/docs/src/inputs.md b/docs/src/inputs.md index 8fc2d0bb..787f99a9 100644 --- a/docs/src/inputs.md +++ b/docs/src/inputs.md @@ -117,8 +117,8 @@ Read more about this feature in the page about [`DataHandler`](@ref datahandling default, the `Throw` condition is used, meaning that interpolating onto a point that is outside the range of definition of the data is not allowed. Other boundary conditions are allowed. With the `Flat` boundary condition, when -interpolating outside of the range of definition, return the value of the -of closest boundary is used instead. +interpolating outside of the range of definition, the value of the closest +boundary is used instead. Another boundary condition that is often useful is `PeriodicCalendar`, which repeats data over and over. diff --git a/ext/InterpolationsRegridderExt.jl b/ext/InterpolationsRegridderExt.jl index f64447aa..709a8162 100644 --- a/ext/InterpolationsRegridderExt.jl +++ b/ext/InterpolationsRegridderExt.jl @@ -80,21 +80,209 @@ Regrid the given data as defined on the given dimensions to the `target_space` i This function is allocating. """ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions) + # TODO: There is room for improvement in this function... + FT = ClimaCore.Spaces.undertype(regridder.target_space) dimensions_FT = map(d -> FT.(d), dimensions) - # Make a linear spline - itp = Intp.extrapolate( - Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())), - regridder.extrapolation_bc, + coordinates = ClimaCore.Fields.coordinate_field(regridder.target_space) + device = ClimaComms.device(regridder.target_space) + + has_3d_z = length(size(last(dimensions))) == 3 + if eltype(coordinates) <: ClimaCore.Geometry.LatLongZPoint && has_3d_z + # If we have 3D altitudes, we do linear in the vertical and bilinear + # horizontal separately + @warn "Ignoring boundary conditions, implementing Periodic, Flat, Flat" + + adapted_data = Adapt.adapt(ClimaComms.array_type(regridder.target_space), data) + xs, ys, zs = dimensions_FT + adapted_xs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), xs) + adapted_ys = Adapt.adapt(ClimaComms.array_type(regridder.target_space), ys) + adapted_zs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), zs) + + return ClimaComms.allowscalar(ClimaComms.device(regridder.target_space)) do + map(regridder.coordinates) do coord + interpolation_3d_z( + adapted_data, + adapted_xs, adapted_ys, adapted_zs, + totuple(coord)..., + ) + end + end + else + # Make a linear spline + itp = Intp.extrapolate( + Intp.interpolate( + dimensions_FT, + FT.(data), + Intp.Gridded(Intp.Linear()), + ), + regridder.extrapolation_bc, + ) + + # Move it to GPU (if needed) + gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + + return map(regridder.coordinates) do coord + gpuitp(totuple(coord)...) + end + end +end + +""" + interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z) + +Perform bilinear + vertical interpolation on a 3D dataset. + +This function first performs linear interpolation along the z-axis at the four +corners of the cell containing the target (x, y) point. Then, it performs +bilinear interpolation in the x-y plane using the z-interpolated values. + +Periodic is implemented on the x direction, Flat on the other ones. + +# Arguments +- `data`: A 3D array of data values. +- `xs`: A vector of x-coordinates corresponding to the first dimension of `data`. +- `ys`: A vector of y-coordinates corresponding to the second dimension of `data`. +- `zs`: A 3D array of z-coordinates. `zs[i, j, :]` provides the z-coordinates for the data point `data[i, j, :]`. +- `target_x`: The x-coordinate of the target point. +- `target_y`: The y-coordinate of the target point. +- `target_z`: The z-coordinate of the target point. +""" +function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z) + # Check boundaries + # if target_x < xs[begin] || target_x > xs[end] + # error( + # "target_x is out of bounds: $(target_x) not in [$(xs[1]), $(xs[end])]", + # ) + # end + # if target_y < ys[begin] || target_y > ys[end] + # error( + # "target_y is out of bounds: $(target_y) not in [$(ys[1]), $(ys[end])]", + # ) + # end + + # Find nearest neighbors + x_period = xs[end] - xs[begin] + target_x = mod(target_x, x_period) + + x_index = searchsortedfirst(xs, target_x) + y_index = searchsortedfirst(ys, target_y) + + x0_index = x_index == 1 ? x_index : x_index - 1 + x1_index = x0_index + 1 + + y0_index = y_index == 1 ? y_index : y_index - 1 + # Flat + y0_index = clamp(y0_index, 1, length(ys) - 1) + y1_index = y0_index + 1 + if y0_index == 1 + target_y = ys[y0_index] + end + if y1_index == length(ys) + target_y = ys[y1_index] + end + + + # Interpolate in z-direction + + z00 = @view zs[x0_index, y0_index, :] + z01 = @view zs[x0_index, y1_index, :] + z10 = @view zs[x1_index, y0_index, :] + z11 = @view zs[x1_index, y1_index, :] + + f00 = linear_interp_z(view(data,x0_index, y0_index, :), z00, target_z) + f01 = linear_interp_z(view(data,x0_index, y1_index, :), z01, target_z) + f10 = linear_interp_z(view(data,x1_index, y0_index, :), z10, target_z) + f11 = linear_interp_z(view(data,x1_index, y1_index, :), z11, target_z) + + # Bilinear interpolation in x-y plane + val = bilinear_interp( + f00, + f01, + f10, + f11, + xs[x0_index], + xs[x1_index], + ys[y0_index], + ys[y1_index], + target_x, + target_y, ) - # Move it to GPU (if needed) - gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp) + return val +end + +""" + linear_interp_z(f, z, target_z) + +Perform linear interpolation along the z-axis. + +# Arguments +- `f`: A vector of function values corresponding to the z-coordinates in `z`. +- `z`: A vector of z-coordinates. +- `target_z`: The z-coordinate at which to interpolate. + +# Returns +The linearly interpolated value at `target_z`. +""" +function linear_interp_z(f, z, target_z) + # if target_z < z[begin] || target_z > z[end] + # error( + # "target_z is out of bounds: $(target_z) not in [$(z[1]), $(z[end])]", + # ) + # end + + index = searchsortedfirst(z, target_z) + # Handle edge cases for index + # Flat + if index == 1 + z0 = z[index] + z1 = z[index + 1] + f0 = f[index] + f1 = f[index + 1] + else + z0 = z[index - 1] + z1 = z[index] + f0 = f[index - 1] + f1 = f[index] + end - return map(regridder.coordinates) do coord - gpuitp(totuple(coord)...) + if index == 1 + target_z = z[index] end + if index == length(z) - 1 + target_z = z[index + 1] + end + val = f0 + (target_z - z0) / (z1 - z0) * (f1 - f0) + return val +end + +""" + bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y) + +Perform bilinear interpolation on a 2D plane. + +# Arguments +- `f00`: Function value at (x0, y0). +- `f01`: Function value at (x0, y1). +- `f10`: Function value at (x1, y0). +- `f11`: Function value at (x1, y1). +- `x0`: x-coordinate of the first corner. +- `x1`: x-coordinate of the second corner. +- `y0`: y-coordinate of the first corner. +- `y1`: y-coordinate of the second corner. +- `target_x`: The x-coordinate of the target point. +- `target_y`: The y-coordinate of the target point. +""" +function bilinear_interp(f00, f01, f10, f11, x0, x1, y0, y1, target_x, target_y) + val = ( + (x1 - target_x) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f00 + + (x1 - target_x) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f01 + + (target_x - x0) * (y1 - target_y) / ((x1 - x0) * (y1 - y0)) * f10 + + (target_x - x0) * (target_y - y0) / ((x1 - x0) * (y1 - y0)) * f11 + ) + return val end end diff --git a/test/interpolations_regridder.jl b/test/interpolations_regridder.jl new file mode 100644 index 00000000..ceeea37e --- /dev/null +++ b/test/interpolations_regridder.jl @@ -0,0 +1,136 @@ +using Test +import ClimaUtilities +import Interpolations +import ClimaUtilities: Regridders +import ClimaComms +import ClimaCore + +linear_interp_z = + Base.get_extension( + ClimaUtilities, + :ClimaUtilitiesClimaCoreInterpolationsExt, + ).InterpolationsRegridderExt.linear_interp_z +bilinear_interp = + Base.get_extension( + ClimaUtilities, + :ClimaUtilitiesClimaCoreInterpolationsExt, + ).InterpolationsRegridderExt.bilinear_interp +interpolation_3d_z = + Base.get_extension( + ClimaUtilities, + :ClimaUtilitiesClimaCoreInterpolationsExt, + ).InterpolationsRegridderExt.interpolation_3d_z + +const context = ClimaComms.context() +ClimaComms.init(context) + +include("TestTools.jl") + +@testset "Interpolation Tests" begin + @testset "linear_interp_z" begin + f = [1.0, 3.0, 5.0] + z = [10.0, 20.0, 30.0] + @test linear_interp_z(f, z, 15.0) ≈ 2.0 + @test linear_interp_z(f, z, 25.0) ≈ 4.0 + @test linear_interp_z(f, z, 10.0) ≈ 1.0 + @test linear_interp_z(f, z, 30.0) ≈ 5.0 + + # Out of bounds + f = [1.0, 3.0] + z = [10.0, 20.0] + @test_throws ErrorException linear_interp_z(f, z, 5.0) + @test_throws ErrorException linear_interp_z(f, z, 25.0) + + + # One point + f = [2.5] + z = [15.0] + @test_throws ErrorException linear_interp_z(f, z, 10.0) + @test_throws ErrorException linear_interp_z(f, z, 20.0) + + # Non uniform spacing + f = [2.0, 4.0, 8.0] + z = [1.0, 3.0, 7.0] + @test linear_interp_z(f, z, 1.0) ≈ 2.0 + @test linear_interp_z(f, z, 3.0) ≈ 4.0 + @test linear_interp_z(f, z, 7.0) ≈ 8.0 + @test linear_interp_z(f, z, 2.0) ≈ 3.0 + @test linear_interp_z(f, z, 5.0) ≈ 6.0 + end + + @testset "interpolation_3d_z" begin + # Test cases for the main 3D interpolation function + + # Create some sample data + xs = [1.0, 2.0, 3.0] + ys = [4.0, 5.0, 6.0] + zs = zeros(3, 3, 8) + for k in 1:8 + for j in 1:3 + for i in 1:3 + zs[i, j, k] = i + j + k + end + end + end + data = reshape(1:(3 * 3 * 8), 3, 3, 8) + + # Exact point + @test interpolation_3d_z(data, xs, ys, zs, xs[1], ys[1], zs[1, 1, 4]) ≈ + data[1, 1, 4] + + # Interpolated point + @test interpolation_3d_z(data, xs, ys, zs, 2.5, 5.5, 7.5) ≈ 20.5 + + # Out of bounds + @test_throws ErrorException interpolation_3d_z( + data, + xs, + ys, + zs, + 0.5, + 4.5, + 3.5, + ) + @test_throws ErrorException interpolation_3d_z( + data, + xs, + ys, + zs, + 2.5, + 4.5, + 4.5, + ) + end + + @testset "Regrid" begin + + lon, lat, z = + collect(-180.0:1:180), collect(-90.0:1:90), collect(0.0:1.0:100.0) + size3D = (361, 181, 101) + data_z3D = zeros(size3D) + + for i in 1:length(lon) + for j in 1:length(lat) + data_z3D[i, j, :] .= z + end + end + dimensions3D = (lon, lat, data_z3D) + + FT = Float64 + spaces = make_spherical_space(FT; context) + hv_center_space = spaces.hybrid + extrapolation_bc = ( + Interpolations.Throw(), + Interpolations.Throw(), + Interpolations.Throw(), + ) + reg_hv = Regridders.InterpolationsRegridder( + hv_center_space; + extrapolation_bc, + ) + regridded_z = Regridders.regrid(reg_hv, data_z3D, dimensions3D) + @test maximum(ClimaCore.Fields.level(regridded_z, 2)) ≈ 0.15 + @test minimum(ClimaCore.Fields.level(regridded_z, 2)) ≈ 0.15 + + end +end