diff --git a/ext/ConservativeRegriddingClimaCoreExt.jl b/ext/ConservativeRegriddingClimaCoreExt.jl index d1c598f..ece7580 100644 --- a/ext/ConservativeRegriddingClimaCoreExt.jl +++ b/ext/ConservativeRegriddingClimaCoreExt.jl @@ -94,7 +94,8 @@ GOCore.best_manifold(mesh::Meshes.AbstractCubedSphere) = GOCore.Spherical(; radi GOCore.best_manifold(topology::Topologies.Topology2D) = GOCore.best_manifold(topology.mesh) GOCore.best_manifold(space::ClimaCore.Spaces.AbstractSpectralElementSpace) = GOCore.best_manifold(Spaces.topology(space)) - +GOCore.best_manifold(field::ClimaCore.Fields.Field) = GOCore.best_manifold(getfield(field, :space)) +Trees.treeify(manifold::GOCore.Spherical, field::ClimaCore.Fields.Field) = Trees.treeify(manifold, getfield(field, :space)) @@ -150,6 +151,28 @@ function get_element_vertices(space) return vertices end + +function get_element_centroids(space) + # Get the indices of the vertices of the elements, in clockwise order for each element + Nh = Meshes.nelements(space.grid.topology.mesh) + Nq = Quadratures.degrees_of_freedom(Spaces.quadrature_style(space)) + vertex_inds = [ + CartesianIndex(i, j, 1, 1, e) # f and v are 1 for SpectralElementSpace2D + for e in 1:Nh, (i, j) in [(1, 1), (Nq, Nq)] + ] # repeat the first coordinate pair at the end + + # Get the lat and lon at each vertex index + coords = Fields.coordinate_field(space) + lonlat_to_usp = GO.UnitSpherical.UnitSphereFromGeographic() + centroids = map(eachslice(vertex_inds; dims = 1)) do (ind1, ind2) + coord1 = (Fields.field_values(coords.long)[ind1], Fields.field_values(coords.lat)[ind1]) + coord2 = (Fields.field_values(coords.long)[ind2], Fields.field_values(coords.lat)[ind2]) + usp_coord1, usp_coord2 = lonlat_to_usp.((coord1, coord2)) + return GO.UnitSpherical.slerp(usp_coord1, usp_coord2, 0.5) + end + return centroids +end + ### These functions are used to facilitate storing a single value per element on a field ### rather than one value per node. """ diff --git a/test/Project.toml b/test/Project.toml index 064d90d..4f9a0dc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,10 +6,13 @@ GeoInterface = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" GeometryOps = "3251bfac-6a57-4b6d-aa61-ac1fef2975ab" GeometryOpsCore = "05efe853-fabf-41c8-927e-7063c8b9f013" Healpix = "9f4e344d-96bc-545a-84a3-ae6b9e1b672b" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LibGEOS = "a90b1aa1-3769-5649-ba7e-abc5a9d163eb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Oceananigans = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09" +RingGrids = "d1845624-ad4f-453b-8ff4-a8db365bf3a7" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SimplexQuad = "f5de1e94-2f77-472d-8c3b-9d09f580ee5e" SortTileRecursiveTree = "746ee33f-1797-42c2-866d-db2fce69d14d" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpeedyWeather = "9e226e20-d153-4fed-8a5b-493def4f21a9" diff --git a/test/runtests.jl b/test/runtests.jl index 618c649..02968c6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,4 +17,6 @@ using Test, SafeTestsets @safetestset "Extensions: Oceananigans" begin include("extensions/oceananigans.jl") end @safetestset "Extensions: ClimaCore" begin include("extensions/climacore.jl") end @safetestset "Extensions: Healpix" begin include("extensions/healpix.jl") end + + @safetestset "Full sweat test" begin include("sweat.jl") end end diff --git a/test/sweat.jl b/test/sweat.jl new file mode 100644 index 0000000..4d0d455 --- /dev/null +++ b/test/sweat.jl @@ -0,0 +1,229 @@ +# Sweat test for all known tree types against each other + +using ConservativeRegridding +using ConservativeRegridding: ConservativeRegridding as CR, Trees +using Test +import GeometryOps as GO, GeoInterface as GI + +import ClimaCore, Oceananigans, Healpix, RingGrids + +const ClimaCoreExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingClimaCoreExt) +const OceananigansExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingOceananigansExt) +const HealpixExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingHealpixExt) +const RingGridsExt = Base.get_extension(ConservativeRegridding, :ConservativeRegriddingRingGridsExt) + +function test_integral_is_conserved(regridder, tree1, values1, tree2, values2, final_values; rtol = sqrt(eps(Float64))) + tree1_areas = ConservativeRegridding.areas(GO.Spherical(), Trees.treeify(tree1)) + tree2_areas = ConservativeRegridding.areas(GO.Spherical(), Trees.treeify(tree2)) + + @test sum(values1 .* tree1_areas) ≈ sum(final_values .* tree2_areas) rtol=rtol + @test sum(values2 .* tree2_areas) ≈ sum(final_values .* tree1_areas) rtol=rtol +end + +function test_intersection_areas_agree(regridder, tree1, tree2; rtol = sqrt(eps(Float64))) + @test sum(regridder.intersections, dims=2)[:, 1] ≈ regridder.dst_areas rtol=rtol + @test sum(regridder.intersections, dims=1)[1, :] ≈ regridder.src_areas rtol=rtol +end + + +function zero_field!(field, values) + set_field_values!(field, values, (x, y, z = 0) -> 0) +end + +# function set_field_values!(field::Oceananigans.Field, values, fun) +# Oceananigans.set!(field, fun) +# values .= vec(Oceananigans.interior(field)) +# end +# function set_field_values!(field::ClimaCore.Fields.Field, values, fun) +# space = getfield(field, :space) +# centroids_latlong = GO.UnitSpherical.GeographicFromUnitSphere().(ClimaCoreExt.get_element_centroids(space)) +# values .= splat(fun).(centroids_latlong) +# # ClimaCoreExt.set_value_per_element!(field, elems) +# end +# function set_field_values!(field::Healpix.HealpixMap, values, fun) +# idxs = 1:length(field.pixels) +# vals = ( +# begin +# theta, phi = Healpix.pix2ang(field, idx) +# lat = deg2rad(90-theta) +# lon = deg2rad(phi) +# fun(lon, lat) +# end +# for idx in idxs +# ) + +# values .= vals +# end + + +import SimplexQuad +using LinearAlgebra: cross, dot, norm +import IterTools + +struct SphericalPolygonIntegrator{X, W} + x::X + w::W + function SphericalPolygonIntegrator(; order=7) + X, W = SimplexQuad.simplexquad(order, 2) # points in barycentric-like coords on unit simplex + new{typeof(X), typeof(W)}(X, W) + end +end + +function (integrator::SphericalPolygonIntegrator)(vertices::AbstractVector{<:GO.UnitSpherical.UnitSphericalPoint}, f) + # Reference triangle: unit 2-simplex + X, W = integrator.x, integrator.w # points in barycentric-like coords on unit simplex + + total = 0.0 + A = vertices[1] + for i in 2:length(vertices)-1 + B, C = vertices[i], vertices[i+1] + det_ABC = dot(A, cross(B, C)) + for k in axes(X, 1) + ξ1, ξ2 = X[k, 1], X[k, 2] + ξ0 = 1 - ξ1 - ξ2 + p = ξ1 * A + ξ2 * B + ξ0 * C + np = norm(p) + s = p / np + J = abs(det_ABC) / np^3 + total += W[k] * f(s) * J + end + end + return total +end + + +function set_field_values!(field, values, fun; integrator = SphericalPolygonIntegrator(; order=7)) + polys = IterTools.ivec(Trees.getcell(Trees.treeify(field))) + values .= Iterators.map(polys) do poly + integrator(GI.getpoint(GI.getexterior(poly)), p -> fun((GO.UnitSpherical.GeographicFromUnitSphere()(p))...)) + end +end + + + + + + +oceananigans_latlong_grid = Oceananigans.LatitudeLongitudeGrid(size=(360, 180, 1), longitude=(0, 360), latitude=(-90, 90), z = (0, 1), radius = GO.Spherical().radius) +oceananigans_tripolar_grid = Oceananigans.TripolarGrid(size=(360, 180, 1), fold_topology = Oceananigans.RightFaceFolded) + +oceananigans_latlong_field = Oceananigans.CenterField(oceananigans_latlong_grid) +oceananigans_tripolar_field = Oceananigans.CenterField(oceananigans_tripolar_grid) + +oceananigans_latlong_vals = vec(Oceananigans.interior(oceananigans_latlong_field)) +oceananigans_tripolar_vals = vec(Oceananigans.interior(oceananigans_tripolar_field)) + +climacore_cubedsphere_grid = ClimaCore.CommonSpaces.CubedSphereSpace(; + radius = GO.Spherical().radius, + n_quad_points = 2, + h_elem = 64, +) +climacore_cubedsphere_field = ClimaCore.Fields.ones(climacore_cubedsphere_grid) +climacore_cubedsphere_vals = zeros(6*climacore_cubedsphere_grid.grid.topology.mesh.ne^2) + +climacore_cubedsphere_gilbert_ordered_grid = let + device = ClimaCore.ClimaComms.device() + context = ClimaCore.ClimaComms.context(device) + h_elem = 64 + h_mesh = ClimaCore.Meshes.EquiangularCubedSphere(ClimaCore.Domains.SphereDomain{Float64}(GO.Spherical().radius), h_elem) + h_topology = ClimaCore.Topologies.Topology2D(context, h_mesh, ClimaCore.Topologies.spacefillingcurve(h_mesh)) + ClimaCore.CommonSpaces.CubedSphereSpace(; + radius = h_mesh.domain.radius, + n_quad_points = 2, + h_elem = h_elem, + h_mesh = h_mesh, + h_topology = h_topology, + ) +end +climacore_cubedsphere_gilbert_ordered_field = ClimaCore.Fields.ones(climacore_cubedsphere_gilbert_ordered_grid) +climacore_cubedsphere_gilbert_ordered_vals = zeros(6*climacore_cubedsphere_gilbert_ordered_grid.grid.topology.mesh.ne^2) + +healpix_nested_order_field = Healpix.HealpixMap{Float64, Healpix.NestedOrder}(64) +healpix_nested_order_vals = healpix_nested_order_field.pixels +healpix_ring_order_field = Healpix.HealpixMap{Float64, Healpix.RingOrder}(64) +healpix_ring_order_vals = healpix_ring_order_field.pixels + +oceananigans_fields = [ + ("Oceananigans longitude-latitude grid", oceananigans_latlong_field, oceananigans_latlong_vals), + ("Oceananigans tripolar grid", oceananigans_tripolar_field, oceananigans_tripolar_vals), +] + +healpix_fields = [ + ("Healpix nested order grid", healpix_nested_order_field, healpix_nested_order_vals), + ("Healpix ring order grid", healpix_ring_order_field, healpix_ring_order_vals), +] + +climacore_fields = [ + ("ClimaCore cubed sphere grid", climacore_cubedsphere_field, climacore_cubedsphere_vals), + ("ClimaCore cubed sphere grid (Gilbert ordered)", climacore_cubedsphere_gilbert_ordered_field, climacore_cubedsphere_gilbert_ordered_vals), +] + +fields = [oceananigans_fields..., climacore_fields..., healpix_fields...] + +regridder_construction_times = Pair{Tuple{String, String}, Float64}[] +@testset "Sweat test" begin + @testset "Sweat test: $name1 -> $name2" for (i, (name1, field1, vals1)) in enumerate(fields), (j, (name2, field2, vals2)) in enumerate(fields) + tic = time() + regridder = @test_nowarn ConservativeRegridding.Regridder(GO.Spherical(), field2, field1; normalize = false) + toc = time() + push!(regridder_construction_times, (name1, name2) => toc - tic) + + # Test that the areas are correct approximately + if !(field2 isa Oceananigans.Field && field2.grid isa Oceananigans.TripolarGrid) && + !(field1 isa Oceananigans.Field && field1.grid isa Oceananigans.TripolarGrid) + test_intersection_areas_agree(regridder, field1, field2) + end + + zero_field!(field1, vals1) + zero_field!(field2, vals2) + + set_field_values!(field1, vals1, ConservativeRegridding.VortexField(; lat0_rad = deg2rad(80))) + ConservativeRegridding.regrid!(vals2, regridder, vals1) + + vals2_regridded = vals2[:] + vals2_analytical = vals2[:] + + # Test that the areas are correct approximately + if !(field2 isa Oceananigans.Field && field2.grid isa Oceananigans.TripolarGrid) && + !(field1 isa Oceananigans.Field && field1.grid isa Oceananigans.TripolarGrid) + # Oceananigans tripolar grid does not cover the globe + test_intersection_areas_agree(regridder, field1, field2) + else + # continue + end + i == j && continue + + # if field2 isa ClimaCore.Fields.Field # TODO: haven't figured out how this can work yet + # continue + # end + + # if field2 isa Healpix.HealpixMap || field1 isa Healpix.HealpixMap + # # Some unknown issue with healpix grids where the regridding to them + # # is not conserving the integral? Not sure what is going on there. + # continue + # end + + @testset "Integral is conserved w.r.t. analytical values" begin + for (fun_name, fun_to_test) in [ + ("Longitude field", ConservativeRegridding.LongitudeField()), + ("Sinusoid field", ConservativeRegridding.SinusoidField()), + ("Harmonic field", ConservativeRegridding.HarmonicField()), + ("Gulf stream field", ConservativeRegridding.GulfStreamField()), + ("Vortex field", ConservativeRegridding.VortexField(; lat0_rad = deg2rad(80))), + ] + @testset "$fun_name" begin + set_field_values!(field1, vals1, fun_to_test) + # zero_field!(field2, vals2) + + ConservativeRegridding.regrid!(vals2, regridder, vals1) + vals2_regridded .= vals2 + + set_field_values!(field2, vals2, fun_to_test) + vals2_analytical .= vals2 + + @test sum(abs.(vals2_regridded) .* regridder.dst_areas) ≈ sum(abs.(vals2_analytical) .* regridder.dst_areas) rtol=1e-6 + end + end + end + end +end \ No newline at end of file