Conservative interpolation #2

@milankl

Documenting first assessments of the conservative interpolation.

Algorithm

See JuliaGeo/GeometryOps.jl#246 (comment) for an illustration of how the algorithm works. In short, we need an "interpolator" $A$ as a sparse matrix that takes data from one grid $g$ (unravelled into a vector) and interpolates it onto another vector, which is then the data on the other grid $h$. Its entries $A_{j,i}$ are the intersection areas between cell $i$ of $g$ and cell $j$ of $h$. We also need the grid cell areas of both grids, $a_g$ and $a_h$.
Then we have

$$h_j = A_{j,i}*g_i / a_{h, j}$$

with Einstein summation over $i$ (summing across columns in the sparse matrix-vector multiplication) and a normalisation by the area $a_{h, j}$ of the respective grid cell. The nice thing is that we can "invert" this operation, i.e. interpolate back from $h$ to $g$ with the transpose $A^T$ of $A$, now summing over $j$.

$$g_i = A^T_{i,j}*h_j / a_{g, i}$$

This "inversion" is not variance-conserving (it is an interpolation after all), but it does conserve the area-weighted mean, hence "conservative".
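A minimal sketch of both formulas on a toy 1D problem, where interval lengths stand in for cell areas; `overlap` and `overlap_matrix` are hypothetical helpers, not part of SpeedyWeather or GeometryOps. Because both grids cover the same domain, each column of $A$ sums to $a_{g,i}$ and each row to $a_{h,j}$, so $\sum_j a_{h,j} h_j = \sum_j \sum_i A_{j,i} g_i = \sum_i a_{g,i} g_i$, and the same argument applies to the back-interpolation.

```julia
using SparseArrays

# Length of the overlap of two 1D cells [a1, b1] and [a2, b2]; 0 if they are disjoint.
overlap(a1, b1, a2, b2) = max(0.0, min(b1, b2) - max(a1, a2))

# A[j, i] = overlap between cell j of grid h and cell i of grid g
function overlap_matrix(edges_h, edges_g)
    nh, ng = length(edges_h) - 1, length(edges_g) - 1
    I, J, V = Int[], Int[], Float64[]
    for j in 1:nh, i in 1:ng
        a = overlap(edges_h[j], edges_h[j+1], edges_g[i], edges_g[i+1])
        if a > 0
            push!(I, j); push!(J, i); push!(V, a)
        end
    end
    return sparse(I, J, V, nh, ng)
end

edges_g = collect(range(0.0, 1.0, length=6))   # grid g: 5 cells of width 0.2
edges_h = [0.0, 0.3, 0.45, 1.0]                # grid h: 3 irregular cells
a_g, a_h = diff(edges_g), diff(edges_h)        # cell "areas" (here: lengths)

A = overlap_matrix(edges_h, edges_g)           # 3×5 sparse interpolator

g = [1.0, 3.0, 2.0, 5.0, 4.0]                  # data on grid g
h = (A * g) ./ a_h                             # h_j = A_{j,i} g_i / a_{h,j}
g_back = (transpose(A) * h) ./ a_g             # g_i = A^T_{i,j} h_j / a_{g,i}

# area-weighted sums agree up to rounding, although g_back differs from g
@show sum(a_g .* g) sum(a_h .* h) sum(a_g .* g_back)
```

Note that `g_back` is not equal to `g`: the back-interpolation smooths the data but keeps the area-weighted sum.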

Intersect areas between two grids

With JuliaGeo/GeometryOps.jl#246 we are still working on computing the intersection areas between two sets of grid cells, and on doing this quickly and accurately in spherical coordinates. We can already test parts of the approach, however, by estimating an upper bound on the intersecting polygons: each cell is simply enlarged to its bounding box (ignoring the meridian wrap-around for now). Here is what I did.

Oceananigans Tripolar grid

Based on https://github.com/CliMA/OrthogonalSphericalShellGrids.jl/pull/64, I define ClimaOcean's TripolarGrid as

using OrthogonalSphericalShellGrids
using Oceananigans
using KernelAbstractions
using GeoMakie

"""
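    list_cell_vertices(grid; add_nans=true)

Return a matrix listing the vertices of all horizontal grid cells in a curvilinear `grid`.
The output is a 6 × (M + 1) matrix of `Point2` elements (5 × (M + 1) with `add_nans=false`), where `M = Nx * Ny`.
Each column lists the vertices of one horizontal cell in clockwise order, starting from and returning to the
southwest (bottom left) corner, followed by a `NaN` separator. The corner entries of the last column are never
filled and stay at (0, 0).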
"""
function list_cell_vertices(grid; add_nans=true)
    Nx, Ny, _ = size(grid)
    FT = eltype(grid)

    cpu_grid = Oceananigans.on_architecture(Oceananigans.CPU(), grid)

    sw  = fill(Point2{FT}(0, 0),     1, Nx*Ny+1)
    nw  = fill(Point2{FT}(0, 0),     1, Nx*Ny+1)
    ne  = fill(Point2{FT}(0, 0),     1, Nx*Ny+1)
    se  = fill(Point2{FT}(0, 0),     1, Nx*Ny+1)
    nan = fill(Point2{FT}(NaN, NaN), 1, Nx*Ny+1)

    # pass the CPU copy of the grid: the kernel runs on the CPU
    Oceananigans.launch!(Oceananigans.CPU(), cpu_grid, :xy, _get_vertices!, sw, nw, ne, se, cpu_grid)
    
    vertices = vcat(sw, nw, ne, se, sw)
    
    if add_nans
        vertices = vcat(vertices, nan)
    end

    return vertices
end

@kernel function _get_vertices!(sw, nw, ne, se, grid)
    i, j = @index(Global, NTuple)

    FT  = eltype(grid)
    Nx  = size(grid, 1)
    # corner superscripts: first for i (⁻ = i, ⁺ = i+1), second for j (⁻ = j, ⁺ = j+1)
    λ⁻⁻ = Oceananigans.λnode(i,   j,   1, grid, Face(), Face(), nothing)
    λ⁺⁻ = Oceananigans.λnode(i+1, j,   1, grid, Face(), Face(), nothing)
    λ⁻⁺ = Oceananigans.λnode(i,   j+1, 1, grid, Face(), Face(), nothing)
    λ⁺⁺ = Oceananigans.λnode(i+1, j+1, 1, grid, Face(), Face(), nothing)
    
    φ⁻⁻ = Oceananigans.φnode(i,   j,   1, grid, Face(), Face(), nothing)
    φ⁺⁻ = Oceananigans.φnode(i+1, j,   1, grid, Face(), Face(), nothing)
    φ⁻⁺ = Oceananigans.φnode(i,   j+1, 1, grid, Face(), Face(), nothing)
    φ⁺⁺ = Oceananigans.φnode(i+1, j+1, 1, grid, Face(), Face(), nothing)

    sw[i+(j-1)*Nx] = Point2{FT}(λ⁻⁻, φ⁻⁻)  
    nw[i+(j-1)*Nx] = Point2{FT}(λ⁻⁺, φ⁻⁺)
    ne[i+(j-1)*Nx] = Point2{FT}(λ⁺⁺, φ⁺⁺)
    se[i+(j-1)*Nx] = Point2{FT}(λ⁺⁻, φ⁺⁻)
end

grid = OrthogonalSphericalShellGrids.TripolarGrid(size = (256, 128, 1), north_poles_latitude = 60)

Upper bound on intersects

I then define possible_intersect as

using SparseArrays   # for spzeros

function possible_intersect(Grid1, nlat_half1, grid)
    
    e1, s1, w1, n1 = RingGrids.get_vertices(Grid1, nlat_half1)

    M = list_cell_vertices(grid, add_nans=false)

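    # rows of M hold the sw, nw, ne, se corners; the e2, s2, w2, n2 names mirror e1, s1, w1, n1
    # but only enter the bounding boxes below, so the label mismatch does not matter
    # add reverse(..., dims=2) to test same south-north ordering of grid points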
    e2 = vcat([m.data[1] for m in M[1, :]]', [m.data[2] for m in M[1, :]]')
    s2 = vcat([m.data[1] for m in M[2, :]]', [m.data[2] for m in M[2, :]]')
    w2 = vcat([m.data[1] for m in M[3, :]]', [m.data[2] for m in M[3, :]]')
    n2 = vcat([m.data[1] for m in M[4, :]]', [m.data[2] for m in M[4, :]]')

    possible_intersect(e1, s1, w1, n1, e2, s2, w2, n2)
end

function possible_intersect(e1, s1, w1, n1, e2, s2, w2, n2)
    
    npoints1 = size(e1, 2)
    npoints2 = size(e2, 2)

    # collect vertices into `faces` arrays
    faces1 = zeros(Float32, 4, 2, npoints1)
    faces2 = zeros(Float32, 4, 2, npoints2)

    for (e, s, w, n, faces) in zip((e1, e2), (s1, s2), (w1, w2), (n1, n2), (faces1, faces2))
        for ij in axes(e, 2)
            faces[1, 1, ij] = e[1, ij]
            faces[2, 1, ij] = s[1, ij]
            faces[3, 1, ij] = w[1, ij]
            faces[4, 1, ij] = n[1, ij]
            
            faces[1, 2, ij] = e[2, ij]
            faces[2, 2, ij] = s[2, ij]
            faces[3, 2, ij] = w[2, ij]
            faces[4, 2, ij] = n[2, ij]
        end
    end
        
    intersects = spzeros(Float32, npoints1, npoints2)

    for ij1 in 1:npoints1

        # get bounding box, treating (lon, lat) as planar (x, y) coordinates
        x1min, x1max = extrema(view(faces1, :, 1, ij1))
        y1min, y1max = extrema(view(faces1, :, 2, ij1))

        for ij2 in 1:npoints2

            # get bounding box for other grid cell
            x2min, x2max = extrema(view(faces2, :, 1, ij2))
            y2min, y2max = extrema(view(faces2, :, 2, ij2))

            # check whether the bounding boxes overlap
            overlap_x = x1min <= x2min <= x1max || x1min <= x2max <= x1max ||
                            x2min <= x1min <= x2max || x2min <= x1max <= x2max

            overlap_y = y1min <= y2min <= y1max || y1min <= y2max <= y1max ||
                            y2min <= y1min <= y2max || y2min <= y1max <= y2max

            if overlap_x && overlap_y
                # set area simply to 1 to fill sparse array
                intersects[ij1, ij2] = 1f0
            end
        end
    end

    return intersects
end
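The double loop above recomputes the bounding box of every cell in `faces2` once per cell in `faces1`, and inserts into the `SparseMatrixCSC` one element at a time, which is slow because each new entry can shift the stored arrays. A sketch of a variant, assuming the `faces` layout from above; `bounding_boxes`, `box_overlap_area` and `possible_intersect_boxes` are hypothetical names. Instead of 1 it stores the overlap area of the two bounding boxes in planar (lon, lat) coordinates, which bounds the polygon overlap from above but is not a spherical area. Cells that only touch along an edge get zero area and are dropped, unlike with the `<=` test above. It still checks all pairs and has not been benchmarked here.

```julia
using SparseArrays

# Hypothetical helper: bounding boxes (xmin, xmax, ymin, ymax) of all cells in a
# (nvertices, 2, ncells) array laid out like `faces1`/`faces2` above.
function bounding_boxes(faces::AbstractArray{T,3}) where T
    boxes = Vector{NTuple{4,T}}(undef, size(faces, 3))
    for k in axes(faces, 3)
        xmin, xmax = extrema(view(faces, :, 1, k))
        ymin, ymax = extrema(view(faces, :, 2, k))
        boxes[k] = (xmin, xmax, ymin, ymax)
    end
    return boxes
end

# Area of the intersection of two boxes in planar (lon, lat) coordinates, 0 if disjoint.
# An upper bound on the polygon intersection, NOT a spherical area; no wrap-around.
function box_overlap_area(b1, b2)
    dx = min(b1[2], b2[2]) - max(b1[1], b2[1])
    dy = min(b1[4], b2[4]) - max(b1[3], b2[3])
    return (dx > 0 && dy > 0) ? dx * dy : zero(dx * dy)
end

# Hypothetical variant of possible_intersect(e1, s1, w1, n1, e2, s2, w2, n2) on `faces` arrays
function possible_intersect_boxes(faces1, faces2)
    boxes1 = bounding_boxes(faces1)     # computed once instead of once per pair
    boxes2 = bounding_boxes(faces2)
    I, J, V = Int[], Int[], Float32[]   # COO triplets, assembled into CSC at the end
    for (ij2, b2) in enumerate(boxes2), (ij1, b1) in enumerate(boxes1)
        a = box_overlap_area(b1, b2)
        if a > 0
            push!(I, ij1); push!(J, ij2); push!(V, a)
        end
    end
    return sparse(I, J, V, length(boxes1), length(boxes2))
end
```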

With both possible_intersect methods defined,

using SpeedyWeather
@time interpolator = possible_intersect(OctaminimalGaussianGrid, 64, grid)

computes the interpolator $A$ in about 17 seconds. Note that each intersection is simply set to 1 here, to get an idea of the sparsity structure, memory footprint and performance. Together with @asinghvi17 we are working on computing the actual intersection areas in spherical coordinates.

  17.472415 seconds (930 allocations: 30.779 MiB, 0.03% gc time)
  16640×32769 SparseMatrixCSC{Float32, Int64} with 344233 stored entries:
  ⎡⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣠⣤⣴⣶⣶⣿⣿⣿⠿⎤
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⣶⣿⡿⠟⠛⠉⠉⠀⠀⠀⠀⠀⎥
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣴⣾⡿⠛⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣴⣾⠿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⣴⡿⠛⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⎥
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣠⣾⠿⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⎥
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣾⠟⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
  ⎢⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢀⣤⣾⠟⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
  ⎢⠀⠀⠀⠀⠀⠀⣀⣠⣴⠾⠋⠉⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥
  ⎣⣠⣤⠤⠶⠞⠛⠉⠁⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎦

Because SpeedyWeather numbers its grid cells from north to south but Oceananigans from south to north, this matrix resembles an anti-diagonal matrix, with a number of off-anti-diagonal bands denoting overlaps between grid cells. The matrix is only a few MB in size, maybe growing to 30 MB or so for a 1° atmosphere coupled to a 1/4° ocean, but nothing unmanageable.
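A back-of-the-envelope check of the "few MB" figure, assuming the standard `SparseMatrixCSC{Float32, Int64}` storage (values, row indices, column pointers); `csc_bytes` is a hypothetical helper. The 30.8 MiB reported by `@time` counts all allocations during construction, not the size of the result.

```julia
# Approximate size of a SparseMatrixCSC{Tv, Ti}: one value (Tv) and one row index (Ti)
# per stored entry, plus ncols + 1 column pointers (Ti); array headers are ignored.
csc_bytes(nstored, ncols; Tv=Float32, Ti=Int64) =
    nstored * (sizeof(Tv) + sizeof(Ti)) + (ncols + 1) * sizeof(Ti)

# numbers from the output above: 344233 stored entries, 32769 columns
bytes = csc_bytes(344_233, 32_769)   # 4392956 bytes
mib = bytes / 2^20                    # ≈ 4.19 MiB
```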

Conservative interpolation performance

I set up some tests with

v1 = rand(Float32, size(interpolator, 1))   # data on grid 1 (SpeedyWeather, rows of A)
v2 = rand(Float32, size(interpolator, 2))   # data on grid 2 (Oceananigans, columns of A)

z1 = zero(v1)
z2 = zero(v2)

area1 = z1 .+ 1     # dummy vector of grid cell areas
area2 = z2 .+ 2

Then we can define an in-place interpolation (conservative once the areas are calculated correctly) like so

using LinearAlgebra

function interpolate!(
    gridout::AbstractVector,
    gridin::AbstractVector,
    interpolator::AbstractMatrix,
    areaout::AbstractVector
)
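    # apply A if the sizes match, otherwise its transpose (interpolating the other way);
    # note that a square A is always applied forward
    if size(interpolator) == (length(gridout), length(gridin))
        LinearAlgebra.mul!(gridout, interpolator, gridin)
    else
        LinearAlgebra.mul!(gridout, transpose(interpolator), gridin)
    end

    gridout ./= areaout     # normalise by the area of each output grid cell
    return gridout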
end

which automatically uses the transpose if the grids are passed in swapped order. Usage and performance are then
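The size check cannot tell the two directions apart when both grids have the same number of points, since a square matrix is always applied forward. A minimal sketch of a variant that takes the direction as a keyword instead; `interpolate_dir!` and `backward` are hypothetical names, not an existing API.

```julia
using LinearAlgebra, SparseArrays

# Hypothetical variant of interpolate! with the direction passed explicitly:
# backward=false applies A (grid g -> h), backward=true applies transpose(A) (h -> g).
function interpolate_dir!(gridout::AbstractVector, gridin::AbstractVector,
                          A::AbstractMatrix, areaout::AbstractVector; backward::Bool=false)
    M = backward ? transpose(A) : A    # lazy transpose, no copy
    mul!(gridout, M, gridin)           # gridout = M * gridin, in place
    gridout ./= areaout                # normalise by the output cell areas
    return gridout
end

# A square 2×2 example, where a size check cannot tell the directions apart
A = sparse([1, 1, 2], [1, 2, 2], [1.0, 2.0, 3.0], 2, 2)   # [1 2; 0 3]
fwd = interpolate_dir!(zeros(2), [1.0, 1.0], A, ones(2))
bwd = interpolate_dir!(zeros(2), [1.0, 1.0], A, ones(2); backward=true)
```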

using BenchmarkTools

@btime interpolate!($z1, $v2, $interpolator, $area1)
#  340.201 μs (0 allocations: 0 bytes)

@btime interpolate!($z2, $v1, $interpolator, $area2)
#  369.240 μs (0 allocations: 0 bytes)

This may reach milliseconds for larger grids, but the sparse matrix-vector product can probably also be done on the GPU. It is a little slower than the bilinear interpolation we use in SpeedyWeather for output; however, that is always a 4-point interpolation, whereas here the stencil can be considerably larger, so I reckon the cost is in a very similar ballpark. The north-south versus south-north ordering also does not seem to matter much (tested with reverse(..., dims=2), see above). Sure, one grid is read backwards, but given that it is all in RAM, maybe that is just not an issue?
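Why the ordering should not change the result: renumbering the cells of one grid only permutes the corresponding columns of $A$ together with the data vector, so the product is identical and only the memory access pattern differs. A toy check with an arbitrary 3 × 4 matrix (it illustrates the algebra only and says nothing about timing):

```julia
using SparseArrays

# Arbitrary toy interpolator (3 cells on grid h, 4 cells on grid g) and data on g
A = sparse([1, 2, 2, 3], [1, 1, 3, 4], [0.5, 0.5, 1.0, 2.0], 3, 4)
g = [1.0, 2.0, 3.0, 4.0]

p = reverse(1:size(A, 2))   # number the cells of grid g in reverse order
A_rev = A[:, p]             # columns of A permuted accordingly
g_rev = g[p]                # data stored in the reversed order

h, h_rev = A * g, A_rev * g_rev
h ≈ h_rev                   # true: same result, only the memory access pattern differs
```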
