Skip to content

Commit ff6d022

Browse files
taimoorsohailnavidcyglwagnersimone-silvestri
authored
(0.99.2) Add XESMF.jl extension to use xESMF to compute tracer regridding weights (#4782)
* Added a general conservative regridder * adds an extension * first ext commit * tweak docs * cleaner * cleanup * better way to import python packages * use PyCall * use PyCall * start implementing PythonCallExt * switch to non-xarray input * comment out xarray * working? * seems to work * docstrings * cleanup * Update Diagnostics.jl * Delete src/Diagnostics/regridder.jl * add docstring to regridding_weights * ensure src and dst have same location * add test * add PythonCall in test * test install * add mwe in test * debug * add docstring * using, not import * debug * load CondaPkg before PythonCall * back * try * Delete CondaPkg.toml * try using CondaPkg in initialization * revert to previous pattern * base extention on XESMF.jl * pythoncall CI -> xesmf CI * fix xesmf tests * add PythonCall in OceananigansXESMFExt deps * use XESMF v0.1.1 * don't duplicate XESMF.jl functionality * separate extract_xesmf_coordinates_structure * delete .CondaPkg * check for locations before coordinates * locations are needed * regrid uniform field and check integral * show integral * few different grids * Update test/test_xesmf.jl Co-authored-by: Gregory L. Wagner <[email protected]> * this is how python.xESMF.regridder expects the coordinates otherwise the returned weights are geebrish! cc @glwagner * integrals match with rtol=1e-4 is this good enough? bad? * add Regridder + regrid! method for ::Regridder * add Regridder * update tests * remove empty line * remove PythonCall * add method * add method * Update OceananigansXESMFExt.jl * Update OceananigansXESMFExt.jl * move Regridder type to XESMF and regrid! method to OceananigansXESMFExt * define regrid! method and add methods to XESMF constructors * use XESMF v0.1.2 or later * enforce XESMF#ncc/define-regridder-struct * no need to specify branch * Update ext/OceananigansXESMFExt.jl Co-authored-by: Simone Silvestri <[email protected]> * Update ext/OceananigansXESMFExt.jl Co-authored-by: Simone Silvestri <[email protected]> * use XESMF v0.1.4 * bump patch release * fix test --------- Co-authored-by: Navid C. Constantinou <[email protected]> Co-authored-by: Gregory Wagner <[email protected]> Co-authored-by: Simone Silvestri <[email protected]>
1 parent 37f3014 commit ff6d022

File tree

7 files changed

+290
-10
lines changed

7 files changed

+290
-10
lines changed

.github/workflows/ci.yml

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ jobs:
4646
echo "[Reactant]" >> LocalPreferences.toml
4747
echo "xla_runtime = \"IFRT\"" >> LocalPreferences.toml
4848
49-
cat LocalPreferences.toml
49+
cat LocalPreferences.toml
5050
- uses: actions/checkout@v5
5151
- uses: julia-actions/setup-julia@v2
5252
with:
@@ -160,3 +160,32 @@ jobs:
160160
env:
161161
TEST_GROUP: "metal"
162162

163+
164+
xesmf:
165+
name: XESMF - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
166+
runs-on: ${{ matrix.os }}
167+
timeout-minutes: 60
168+
strategy:
169+
fail-fast: false
170+
matrix:
171+
version:
172+
- '1.10'
173+
os:
174+
- ubuntu-latest
175+
arch:
176+
- x64
177+
include:
178+
- os: macOS-latest
179+
arch: aarch64
180+
version: '1.10'
181+
steps:
182+
- uses: actions/checkout@v5
183+
- uses: julia-actions/setup-julia@v2
184+
with:
185+
version: ${{ matrix.version }}
186+
arch: ${{ matrix.arch }}
187+
- uses: julia-actions/cache@v2
188+
- uses: julia-actions/julia-buildpkg@v1
189+
- uses: julia-actions/julia-runtest@v1
190+
env:
191+
TEST_GROUP: "xesmf"

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Oceananigans"
22
uuid = "9e8cae18-63c1-5223-a75c-80ca9d6e9a09"
33
authors = ["Climate Modeling Alliance and contributors"]
4-
version = "0.99.1"
4+
version = "0.99.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -46,6 +46,7 @@ MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b"
4646
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4747
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
4848
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
49+
XESMF = "2e0b0046-e7a1-486f-88de-807ee8ffabe5"
4950
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
5051

5152
[extensions]
@@ -57,6 +58,7 @@ OceananigansMetalExt = "Metal"
5758
OceananigansNCDatasetsExt = "NCDatasets"
5859
OceananigansOneAPIExt = "oneAPI"
5960
OceananigansReactantExt = ["Reactant", "KernelAbstractions", "ConstructionBase"]
61+
OceananigansXESMFExt = ["XESMF"]
6062

6163
[compat]
6264
AMDGPU = "1.3.6, 2"
@@ -99,6 +101,7 @@ StaticArrays = "1"
99101
Statistics = "1.9"
100102
StructArrays = "0.4, 0.5, 0.6, 0.7"
101103
TimesDates = "0.3"
104+
XESMF = "0.1.4"
102105
julia = "1.9"
103106
oneAPI = "2.0.1"
104107

@@ -117,4 +120,4 @@ TimesDates = "bdfc003b-8df8-5c39-adcd-3a9087f5df4a"
117120
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
118121

119122
[targets]
120-
test = ["AMDGPU", "CUDA", "oneAPI", "DataDeps", "SafeTestsets", "Test", "Enzyme", "Reactant", "Metal", "CUDA_Runtime_jll", "MPIPreferences", "TimesDates", "NCDatasets"]
123+
test = ["AMDGPU", "CUDA", "oneAPI", "DataDeps", "SafeTestsets", "Test", "Enzyme", "Reactant", "Metal", "XESMF", "CUDA_Runtime_jll", "MPIPreferences", "TimesDates", "NCDatasets"]

ext/OceananigansXESMFExt.jl

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
module OceananigansXESMFExt
2+
3+
using XESMF
4+
using Oceananigans
5+
using Oceananigans.Architectures: architecture, on_architecture
6+
using Oceananigans.Fields: AbstractField, topology, location
7+
using Oceananigans.Grids: λnodes, φnodes, Center, Face, total_length
8+
9+
import Oceananigans.Fields: regrid!
10+
import XESMF: Regridder, extract_xesmf_coordinates_structure
11+
12+
function x_node_array(x::AbstractVector, Nx, Ny)
13+
return Array(repeat(view(x, 1:Nx), 1, Ny))'
14+
end
15+
function y_node_array(x::AbstractVector, Nx, Ny)
16+
return Array(repeat(view(x, 1:Ny)', Nx, 1))'
17+
end
18+
x_node_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx, 1:Ny))'
19+
20+
function x_vertex_array(x::AbstractVector, Nx, Ny)
21+
return Array(repeat(view(x, 1:Nx+1), 1, Ny+1))'
22+
end
23+
function y_vertex_array(x::AbstractVector, Nx, Ny)
24+
return Array(repeat(view(x, 1:Ny+1)', Nx+1, 1))'
25+
end
26+
x_vertex_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx+1, 1:Ny+1))'
27+
28+
y_node_array(x::AbstractMatrix, Nx, Ny) = x_node_array(x, Nx, Ny)
29+
y_vertex_array(x::AbstractMatrix, Nx, Ny) = x_vertex_array(x, Nx, Ny)
30+
31+
function extract_xesmf_coordinates_structure(dst_field::AbstractField, src_field::AbstractField)
32+
33+
ℓx, ℓy, ℓz = Oceananigans.Fields.instantiated_location(src_field)
34+
35+
dst_grid = dst_field.grid
36+
src_grid = src_field.grid
37+
38+
# Extract center coordinates from both fields
39+
λᵈ = λnodes(dst_grid, Center(), Center(), ℓz, with_halos=true)
40+
φᵈ = φnodes(dst_grid, Center(), Center(), ℓz, with_halos=true)
41+
λˢ = λnodes(src_grid, Center(), Center(), ℓz, with_halos=true)
42+
φˢ = φnodes(src_grid, Center(), Center(), ℓz, with_halos=true)
43+
44+
# Extract cell vertices
45+
λvᵈ = λnodes(dst_grid, Face(), Face(), ℓz, with_halos=true)
46+
φvᵈ = φnodes(dst_grid, Face(), Face(), ℓz, with_halos=true)
47+
λvˢ = λnodes(src_grid, Face(), Face(), ℓz, with_halos=true)
48+
φvˢ = φnodes(src_grid, Face(), Face(), ℓz, with_halos=true)
49+
50+
# Build data structures expected by xESMF
51+
Nˢx, Nˢy, Nˢz = size(src_field)
52+
Nᵈx, Nᵈy, Nᵈz = size(dst_field)
53+
54+
λᵈ = x_node_array(λᵈ, Nᵈx, Nᵈy)
55+
φᵈ = y_node_array(φᵈ, Nᵈx, Nᵈy)
56+
λˢ = x_node_array(λˢ, Nˢx, Nˢy)
57+
φˢ = y_node_array(φˢ, Nˢx, Nˢy)
58+
59+
λvᵈ = x_vertex_array(λvᵈ, Nᵈx, Nᵈy)
60+
φvᵈ = y_vertex_array(φvᵈ, Nᵈx, Nᵈy)
61+
λvˢ = x_vertex_array(λvˢ, Nˢx, Nˢy)
62+
φvˢ = y_vertex_array(φvˢ, Nˢx, Nˢy)
63+
64+
dst_coordinates = Dict("lat" => φᵈ, # φ is latitude
65+
"lon" => λᵈ, # λ is longitude
66+
"lat_b" => φvᵈ,
67+
"lon_b" => λvᵈ)
68+
69+
src_coordinates = Dict("lat" => φˢ, # φ is latitude
70+
"lon" => λˢ, # λ is longitude
71+
"lat_b" => φvˢ,
72+
"lon_b" => λvˢ)
73+
74+
return dst_coordinates, src_coordinates
75+
end
76+
77+
"""
78+
Regridder(dst_field::AbstractField, src_field::AbstractField; method="conservative")
79+
80+
Return a regridder from `src_field` to `dst_field` using the specified `method`.
81+
The regridder contains a sparse matrix with the regridding weights.
82+
The regridding weights are obtained via xESMF Python package.
83+
xESMF exposes five different regridding algorithms from the ESMF library,
84+
specified with the `method` keyword argument:
85+
86+
* `"bilinear"`: `ESMF.RegridMethod.BILINEAR`
87+
* `"conservative"`: `ESMF.RegridMethod.CONSERVE`
88+
* `"conservative_normed"`: `ESMF.RegridMethod.CONSERVE`
89+
* `"patch"`: `ESMF.RegridMethod.PATCH`
90+
* `"nearest_s2d"`: `ESMF.RegridMethod.NEAREST_STOD`
91+
* `"nearest_d2s"`: `ESMF.RegridMethod.NEAREST_DTOS`
92+
93+
where `conservative_normed` is just the conservative method with the normalization set to
94+
`ESMF.NormType.FRACAREA` instead of the default `norm_type = ESMF.NormType.DSTAREA`.
95+
96+
For more information, see the Python xESMF documentation at:
97+
98+
> https://xesmf.readthedocs.io/en/latest/notebooks/Compare_algorithms.html
99+
100+
Example
101+
=======
102+
103+
```@example
104+
using Oceananigans
105+
using XESMF
106+
107+
z = (-1, 0)
108+
tg = TripolarGrid(; size=(360, 170, 1), z, southernmost_latitude = -80)
109+
llg = LatitudeLongitudeGrid(; size=(360, 180, 1), z,
110+
longitude=(0, 360), latitude=(-82, 90))
111+
112+
src_field = CenterField(tg)
113+
dst_field = CenterField(llg)
114+
115+
regridder = Oceananigans.Fields.Regridder(dst_field, src_field, method="conservative")
116+
```
117+
"""
118+
function Regridder(dst_field::AbstractField, src_field::AbstractField; method="conservative")
119+
120+
ℓx, ℓy, ℓz = Oceananigans.Fields.instantiated_location(src_field)
121+
122+
# We only support regridding between centered fields
123+
@assert ℓx isa Center
124+
@assert ℓy isa Center
125+
@assert (ℓx, ℓy, ℓz) == Oceananigans.Fields.instantiated_location(dst_field)
126+
127+
src_Nz = size(src_field)[3]
128+
dst_Nz = size(dst_field)[3]
129+
@assert src_field.grid.z.cᵃᵃᶠ[1:src_Nz+1] == dst_field.grid.z.cᵃᵃᶠ[1:dst_Nz+1]
130+
131+
dst_coordinates, src_coordinates = extract_xesmf_coordinates_structure(dst_field, src_field)
132+
periodic = Oceananigans.Grids.topology(src_field.grid, 1) === Periodic ? true : false
133+
134+
regridder = XESMF.Regridder(src_coordinates, dst_coordinates; method, periodic)
135+
weights = regridder.weights
136+
137+
arch = architecture(src_field)
138+
139+
weights = on_architecture(arch, weights)
140+
141+
temp_src = on_architecture(architecture(src_field), regridder.src_temp)
142+
temp_dst = on_architecture(architecture(dst_field), regridder.dst_temp)
143+
144+
return XESMF.Regridder(method, weights, temp_src, temp_dst)
145+
end
146+
147+
"""
148+
regrid!(dst_field, regrider::XESMF.Regridder, src_field)
149+
150+
Regrid `src_field` onto the grid of field `dst_field` using the regrider `r`.
151+
152+
Example
153+
=======
154+
155+
```@example
156+
using Oceananigans
157+
using XESMF
158+
159+
z = (-1, 0)
160+
161+
tg = TripolarGrid(; size=(360, 170, 1), z, southernmost_latitude = -80)
162+
163+
llg = LatitudeLongitudeGrid(; size=(360, 180, 1), z,
164+
longitude=(0, 360), latitude=(-82, 90))
165+
166+
src_field = CenterField(tg)
167+
dst_field = CenterField(llg)
168+
169+
λ₀, φ₀ = 150, 30. # degrees
170+
width = 12 # degrees
171+
set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 + (φ - φ₀)^2) / 2width^2))
172+
173+
regridder = XESMF.Regridder(dst_field, src_field, method="conservative")
174+
175+
regrid!(dst_field, regridder, src_field)
176+
177+
first(Field(Integral(dst_field, dims=(1, 2))))
178+
```
179+
"""
180+
function regrid!(dst_field, regridder::XESMF.Regridder, src_field)
181+
Nz = size(src_field.grid)[3]
182+
topo_z = topology(src_field)[3]()
183+
ℓz = location(src_field)[3]()
184+
185+
for k in 1:total_length(ℓz, topo_z, Nz)
186+
src = vec(interior(src_field, :, :, k))
187+
dst = vec(interior(dst_field, :, :, k))
188+
regridder(dst, src)
189+
end
190+
191+
return dst_field
192+
end
193+
194+
end # module

src/Fields/regridding_fields.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@ using KernelAbstractions: @kernel, @index
22

33
using Oceananigans.Architectures: on_architecture, architecture
44
using Oceananigans.Operators: Δzᶜᶜᶜ, Δyᶜᶜᶜ, Δxᶜᶜᶜ, Azᶜᶜᶜ
5-
using Oceananigans.Grids: hack_sind, ξnode, ηnode, rnode
5+
using Oceananigans.Grids: hack_sind, ξnode, ηnode, rnode, total_length
6+
using LinearAlgebra
67

78
using Base: ForwardOrdering
89

910
const f = Face()
1011
const c = Center()
1112

1213
"""
13-
regrid!(a, b)
14+
regrid!(dst_field, src_field)
1415
15-
Regrid field `b` onto the grid of field `a`.
16+
Regrid `src_field` onto the grid of `dst_field`.
1617
1718
Example
1819
=======
@@ -44,7 +45,8 @@ output_field[1, 1, :]
4445
0.0
4546
```
4647
"""
47-
regrid!(a, b) = regrid!(a, a.grid, b.grid, b)
48+
regrid!(dst_field, src_field) =
49+
regrid!(dst_field, dst_field.grid, src_field.grid, src_field)
4850

4951
function we_can_regrid_in_z(a, target_grid, source_grid, b)
5052
# Check that

src/Grids/abstract_grid.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,3 @@ grid_name(grid::AbstractGrid) = typeof(grid).name.wrapper
108108
TX, TY, TZ = topology(grid)
109109
return (topology_str(TX), topology_str(TY), topology_str(TZ))
110110
end
111-

test/runtests.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ CUDA.allowscalar() do
231231
end
232232
end
233233

234-
234+
235235
# Tests for Enzyme extension
236236
if group == :enzyme || group == :all
237237
@testset "Enzyme extension tests" begin
@@ -281,10 +281,16 @@ CUDA.allowscalar() do
281281
end
282282
end
283283

284+
# Tests for XESMF extension
285+
if group == :xesmf || group == :all
286+
@testset "XESMF extension tests" begin
287+
include("test_xesmf.jl")
288+
end
289+
end
290+
284291
if group == :convergence
285292
include("test_convergence.jl")
286293
end
287294
end
288295

289296
end #CUDA.allowscalar()
290-

test/test_xesmf.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
include("dependencies_for_runtests.jl")
2+
3+
using XESMF
4+
using SparseArrays
5+
using LinearAlgebra
6+
7+
z = (-1, 0)
8+
southernmost_latitude = -80
9+
radius = Oceananigans.Grids.R_Earth
10+
11+
llg_coarse = LatitudeLongitudeGrid(; size=(176, 88, 1),
12+
longitude=(0, 360),
13+
latitude=(southernmost_latitude, 90),
14+
z, radius)
15+
16+
llg_fine = LatitudeLongitudeGrid(; size=(360, 180, 1),
17+
longitude=(0, 360),
18+
latitude=(southernmost_latitude, 90),
19+
z, radius)
20+
21+
tg = TripolarGrid(; size=(360, 170, 1), z, southernmost_latitude, radius)
22+
23+
@testset "XESMF extension" begin
24+
25+
for (src_grid, dst_grid) in ((llg_coarse, llg_fine),
26+
(llg_fine, llg_coarse),
27+
(tg, llg_fine))
28+
29+
@info " Regridding from $(nameof(typeof(src_grid))) to $(nameof(typeof(dst_grid)))"
30+
31+
src_field = CenterField(src_grid)
32+
dst_field = CenterField(dst_grid)
33+
34+
λ₀, φ₀ = 150, 30. # degrees
35+
width = 12 # degrees
36+
set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 +- φ₀)^2) / 2width^2))
37+
38+
R = XESMF.Regridder(dst_field, src_field)
39+
@test R.weights isa SparseMatrixCSC
40+
41+
regrid!(dst_field, R, src_field)
42+
43+
# ∫ dst_field dA = ∫ src_field dA
44+
@test isapprox(first(Field(Integral(dst_field, dims=(1, 2)))),
45+
first(Field(Integral(src_field, dims=(1, 2)))), rtol=1e-4)
46+
end
47+
end

0 commit comments

Comments
 (0)