Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion ext/ConservativeRegriddingClimaCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ GOCore.best_manifold(mesh::Meshes.AbstractCubedSphere) = GOCore.Spherical(; radi
GOCore.best_manifold(topology::Topologies.Topology2D) = GOCore.best_manifold(topology.mesh)
GOCore.best_manifold(space::ClimaCore.Spaces.AbstractSpectralElementSpace) = GOCore.best_manifold(Spaces.topology(space))


GOCore.best_manifold(field::ClimaCore.Fields.Field) = GOCore.best_manifold(getfield(field, :space))
Trees.treeify(manifold::GOCore.Spherical, field::ClimaCore.Fields.Field) = Trees.treeify(manifold, getfield(field, :space))



Expand Down Expand Up @@ -150,6 +151,28 @@ function get_element_vertices(space)
return vertices
end


function get_element_centroids(space)
# Get the indices of the vertices of the elements, in clockwise order for each element
Nh = Meshes.nelements(space.grid.topology.mesh)
Nq = Quadratures.degrees_of_freedom(Spaces.quadrature_style(space))
vertex_inds = [
CartesianIndex(i, j, 1, 1, e) # f and v are 1 for SpectralElementSpace2D
for e in 1:Nh, (i, j) in [(1, 1), (Nq, Nq)]
] # repeat the first coordinate pair at the end

# Get the lat and lon at each vertex index
coords = Fields.coordinate_field(space)
lonlat_to_usp = GO.UnitSpherical.UnitSphereFromGeographic()
centroids = map(eachslice(vertex_inds; dims = 1)) do (ind1, ind2)
coord1 = (Fields.field_values(coords.long)[ind1], Fields.field_values(coords.lat)[ind1])
coord2 = (Fields.field_values(coords.long)[ind2], Fields.field_values(coords.lat)[ind2])
usp_coord1, usp_coord2 = lonlat_to_usp.((coord1, coord2))
return GO.UnitSpherical.slerp(usp_coord1, usp_coord2, 0.5)
end
return centroids
end

### These functions are used to facilitate storing a single value per element on a field
### rather than one value per node.
"""
Expand Down
3 changes: 3 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ GeoInterface = "cf35fbd7-0cd7-5166-be24-54bfbe79505f"
GeometryOps = "3251bfac-6a57-4b6d-aa61-ac1fef2975ab"
GeometryOpsCore = "05efe853-fabf-41c8-927e-7063c8b9f013"
Healpix = "9f4e344d-96bc-545a-84a3-ae6b9e1b672b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"
RingGrids = "d1845624-ad4f-453b-8ff4-a8db365bf3a7"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SimplexQuad = "f5de1e94-2f77-472d-8c3b-9d09f580ee5e"
SortTileRecursiveTree = "746ee33f-1797-42c2-866d-db2fce69d14d"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpeedyWeather = "9e226e20-d153-4fed-8a5b-493def4f21a9"
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ using Test, SafeTestsets
@safetestset "Extensions: Oceananigans" begin include("extensions/oceananigans.jl") end
@safetestset "Extensions: ClimaCore" begin include("extensions/climacore.jl") end
@safetestset "Extensions: Healpix" begin include("extensions/healpix.jl") end

@safetestset "Full sweat test" begin include("sweat.jl") end
end
229 changes: 229 additions & 0 deletions test/sweat.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
# Sweat test for all known tree types against each other

using ConservativeRegridding
using ConservativeRegridding: ConservativeRegridding as CR, Trees
using Test
import GeometryOps as GO, GeoInterface as GI

import ClimaCore, Oceananigans, Healpix, RingGrids

const ClimaCoreExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingClimaCoreExt)
const OceananigansExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingOceananigansExt)
const HealpixExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingHealpixExt)
const RingGridsExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingRingGridsExt)

function test_integral_is_conserved(regridder, tree1, values1, tree2, values2, final_values; rtol = sqrt(eps(Float64)))
tree1_areas = ConservativeRegridding.areas(GO.Spherical(), Trees.treeify(tree1))
tree2_areas = ConservativeRegridding.areas(GO.Spherical(), Trees.treeify(tree2))

@test sum(values1 .* tree1_areas) ≈ sum(final_values .* tree2_areas) rtol=rtol
@test sum(values2 .* tree2_areas) ≈ sum(final_values .* tree1_areas) rtol=rtol
end

function test_intersection_areas_agree(regridder, tree1, tree2; rtol = sqrt(eps(Float64)))
@test sum(regridder.intersections, dims=2)[:, 1] ≈ regridder.dst_areas rtol=rtol
@test sum(regridder.intersections, dims=1)[1, :] ≈ regridder.src_areas rtol=rtol
end


function zero_field!(field, values)
set_field_values!(field, values, (x, y, z = 0) -> 0)
end

# function set_field_values!(field::Oceananigans.Field, values, fun)
# Oceananigans.set!(field, fun)
# values .= vec(Oceananigans.interior(field))
# end
# function set_field_values!(field::ClimaCore.Fields.Field, values, fun)
# space = getfield(field, :space)
# centroids_latlong = GO.UnitSpherical.GeographicFromUnitSphere().(ClimaCoreExt.get_element_centroids(space))
# values .= splat(fun).(centroids_latlong)
# # ClimaCoreExt.set_value_per_element!(field, elems)
# end
# function set_field_values!(field::Healpix.HealpixMap, values, fun)
# idxs = 1:length(field.pixels)
# vals = (
# begin
# theta, phi = Healpix.pix2ang(field, idx)
# lat = deg2rad(90-theta)
# lon = deg2rad(phi)
# fun(lon, lat)
# end
# for idx in idxs
# )

# values .= vals
# end


import SimplexQuad
using LinearAlgebra: cross, dot, norm
import IterTools

struct SphericalPolygonIntegrator{X, W}
x::X
w::W
function SphericalPolygonIntegrator(; order=7)
X, W = SimplexQuad.simplexquad(order, 2) # points in barycentric-like coords on unit simplex
new{typeof(X), typeof(W)}(X, W)
end
end

function (integrator::SphericalPolygonIntegrator)(vertices::AbstractVector{<:GO.UnitSpherical.UnitSphericalPoint}, f)
# Reference triangle: unit 2-simplex
X, W = integrator.x, integrator.w # points in barycentric-like coords on unit simplex

total = 0.0
A = vertices[1]
for i in 2:length(vertices)-1
B, C = vertices[i], vertices[i+1]
det_ABC = dot(A, cross(B, C))
for k in axes(X, 1)
ξ1, ξ2 = X[k, 1], X[k, 2]
ξ0 = 1 - ξ1 - ξ2
p = ξ1 * A + ξ2 * B + ξ0 * C
np = norm(p)
s = p / np
J = abs(det_ABC) / np^3
total += W[k] * f(s) * J
end
end
return total
end


function set_field_values!(field, values, fun; integrator = SphericalPolygonIntegrator(; order=7))
polys = IterTools.ivec(Trees.getcell(Trees.treeify(field)))
values .= Iterators.map(polys) do poly
integrator(GI.getpoint(GI.getexterior(poly)), p -> fun((GO.UnitSpherical.GeographicFromUnitSphere()(p))...))
end
end






oceananigans_latlong_grid = Oceananigans.LatitudeLongitudeGrid(size=(360, 180, 1), longitude=(0, 360), latitude=(-90, 90), z = (0, 1), radius = GO.Spherical().radius)
oceananigans_tripolar_grid = Oceananigans.TripolarGrid(size=(360, 180, 1), fold_topology = Oceananigans.RightFaceFolded)

oceananigans_latlong_field = Oceananigans.CenterField(oceananigans_latlong_grid)
oceananigans_tripolar_field = Oceananigans.CenterField(oceananigans_tripolar_grid)

oceananigans_latlong_vals = vec(Oceananigans.interior(oceananigans_latlong_field))
oceananigans_tripolar_vals = vec(Oceananigans.interior(oceananigans_tripolar_field))

climacore_cubedsphere_grid = ClimaCore.CommonSpaces.CubedSphereSpace(;
radius = GO.Spherical().radius,
n_quad_points = 2,
h_elem = 64,
)
climacore_cubedsphere_field = ClimaCore.Fields.ones(climacore_cubedsphere_grid)
climacore_cubedsphere_vals = zeros(6*climacore_cubedsphere_grid.grid.topology.mesh.ne^2)

climacore_cubedsphere_gilbert_ordered_grid = let
device = ClimaCore.ClimaComms.device()
context = ClimaCore.ClimaComms.context(device)
h_elem = 64
h_mesh = ClimaCore.Meshes.EquiangularCubedSphere(ClimaCore.Domains.SphereDomain{Float64}(GO.Spherical().radius), h_elem)
h_topology = ClimaCore.Topologies.Topology2D(context, h_mesh, ClimaCore.Topologies.spacefillingcurve(h_mesh))
ClimaCore.CommonSpaces.CubedSphereSpace(;
radius = h_mesh.domain.radius,
n_quad_points = 2,
h_elem = h_elem,
h_mesh = h_mesh,
h_topology = h_topology,
)
end
climacore_cubedsphere_gilbert_ordered_field = ClimaCore.Fields.ones(climacore_cubedsphere_gilbert_ordered_grid)
climacore_cubedsphere_gilbert_ordered_vals = zeros(6*climacore_cubedsphere_gilbert_ordered_grid.grid.topology.mesh.ne^2)

healpix_nested_order_field = Healpix.HealpixMap{Float64, Healpix.NestedOrder}(64)
healpix_nested_order_vals = healpix_nested_order_field.pixels
healpix_ring_order_field = Healpix.HealpixMap{Float64, Healpix.RingOrder}(64)
healpix_ring_order_vals = healpix_ring_order_field.pixels

oceananigans_fields = [
("Oceananigans longitude-latitude grid", oceananigans_latlong_field, oceananigans_latlong_vals),
("Oceananigans tripolar grid", oceananigans_tripolar_field, oceananigans_tripolar_vals),
]

healpix_fields = [
("Healpix nested order grid", healpix_nested_order_field, healpix_nested_order_vals),
("Healpix ring order grid", healpix_ring_order_field, healpix_ring_order_vals),
]

climacore_fields = [
("ClimaCore cubed sphere grid", climacore_cubedsphere_field, climacore_cubedsphere_vals),
("ClimaCore cubed sphere grid (Gilbert ordered)", climacore_cubedsphere_gilbert_ordered_field, climacore_cubedsphere_gilbert_ordered_vals),
]

fields = [oceananigans_fields..., climacore_fields..., healpix_fields...]

regridder_construction_times = Pair{Tuple{String, String}, Float64}[]
@testset "Sweat test" begin
@testset "Sweat test: $name1 -> $name2" for (i, (name1, field1, vals1)) in enumerate(fields), (j, (name2, field2, vals2)) in enumerate(fields)
tic = time()
regridder = @test_nowarn ConservativeRegridding.Regridder(GO.Spherical(), field2, field1; normalize = false)
toc = time()
push!(regridder_construction_times, (name1, name2) => toc - tic)

# Test that the areas are correct approximately
if !(field2 isa Oceananigans.Field && field2.grid isa Oceananigans.TripolarGrid) &&
!(field1 isa Oceananigans.Field && field1.grid isa Oceananigans.TripolarGrid)
test_intersection_areas_agree(regridder, field1, field2)
end

zero_field!(field1, vals1)
zero_field!(field2, vals2)

set_field_values!(field1, vals1, ConservativeRegridding.VortexField(; lat0_rad = deg2rad(80)))
ConservativeRegridding.regrid!(vals2, regridder, vals1)

vals2_regridded = vals2[:]
vals2_analytical = vals2[:]

# Test that the areas are correct approximately
if !(field2 isa Oceananigans.Field && field2.grid isa Oceananigans.TripolarGrid) &&
!(field1 isa Oceananigans.Field && field1.grid isa Oceananigans.TripolarGrid)
# Oceananigans tripolar grid does not cover the globe
test_intersection_areas_agree(regridder, field1, field2)
else
# continue
end
i == j && continue

# if field2 isa ClimaCore.Fields.Field # TODO: haven't figured out how this can work yet
# continue
# end

# if field2 isa Healpix.HealpixMap || field1 isa Healpix.HealpixMap
# # Some unknown issue with healpix grids where the regridding to them
# # is not conserving the integral? Not sure what is going on there.
# continue
# end

@testset "Integral is conserved w.r.t. analytical values" begin
for (fun_name, fun_to_test) in [
("Longitude field", ConservativeRegridding.LongitudeField()),
("Sinusoid field", ConservativeRegridding.SinusoidField()),
("Harmonic field", ConservativeRegridding.HarmonicField()),
("Gulf stream field", ConservativeRegridding.GulfStreamField()),
("Vortex field", ConservativeRegridding.VortexField(; lat0_rad = deg2rad(80))),
]
@testset "$fun_name" begin
set_field_values!(field1, vals1, fun_to_test)
# zero_field!(field2, vals2)

ConservativeRegridding.regrid!(vals2, regridder, vals1)
vals2_regridded .= vals2

set_field_values!(field2, vals2, fun_to_test)
vals2_analytical .= vals2

@test sum(abs.(vals2_regridded) .* regridder.dst_areas) ≈ sum(abs.(vals2_analytical) .* regridder.dst_areas) rtol=1e-6
end
end
end
end
end
Loading