Skip to content

Commit cbfb6f2

Browse files
authored
Further enhancements of the XESMF extension (#4819)
* tweaks * generalize xesmf regridder for GPUs * fix example * cleanup * add comment
1 parent fc6bbd7 commit cbfb6f2

File tree

5 files changed

+60
-81
lines changed

5 files changed

+60
-81
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ steps:
8787
- "🧅 multi_region"
8888
- "🦧 scripts"
8989
- "👺 enzyme"
90+
- "🍱 xesmf"
9091
- "👹 reactant_1"
9192
- "🎭 reactant_2"
9293
retry:

.github/workflows/ci.yml

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -159,33 +159,3 @@ jobs:
159159
- uses: julia-actions/julia-runtest@v1
160160
env:
161161
TEST_GROUP: "metal"
162-
163-
164-
xesmf:
165-
name: XESMF - Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
166-
runs-on: ${{ matrix.os }}
167-
timeout-minutes: 60
168-
strategy:
169-
fail-fast: false
170-
matrix:
171-
version:
172-
- '1.10'
173-
os:
174-
- ubuntu-latest
175-
arch:
176-
- x64
177-
include:
178-
- os: macOS-latest
179-
arch: aarch64
180-
version: '1.10'
181-
steps:
182-
- uses: actions/checkout@v5
183-
- uses: julia-actions/setup-julia@v2
184-
with:
185-
version: ${{ matrix.version }}
186-
arch: ${{ matrix.arch }}
187-
- uses: julia-actions/cache@v2
188-
- uses: julia-actions/julia-buildpkg@v1
189-
- uses: julia-actions/julia-runtest@v1
190-
env:
191-
TEST_GROUP: "xesmf"

ext/OceananigansXESMFExt.jl

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,23 @@ using Oceananigans.Grids: λnodes, φnodes, Center, Face, total_length
99
import Oceananigans.Fields: regrid!
1010
import XESMF: Regridder, extract_xesmf_coordinates_structure
1111

12-
function x_node_array(x::AbstractVector, Nx, Ny)
13-
return Array(repeat(view(x, 1:Nx), 1, Ny))'
14-
end
15-
function y_node_array(x::AbstractVector, Nx, Ny)
16-
return Array(repeat(view(x, 1:Ny)', Nx, 1))'
17-
end
18-
x_node_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx, 1:Ny))'
12+
# permutedims below is used because Python's xESMF expects
13+
# 2D arrays with (x, y) coordinates with y varying in dim=1 and x varying in dim=2
1914

20-
function x_vertex_array(x::AbstractVector, Nx, Ny)
21-
return Array(repeat(view(x, 1:Nx+1), 1, Ny+1))'
22-
end
23-
function y_vertex_array(x::AbstractVector, Nx, Ny)
24-
return Array(repeat(view(x, 1:Ny+1)', Nx+1, 1))'
25-
end
26-
x_vertex_array(x::AbstractMatrix, Nx, Ny) = Array(view(x, 1:Nx+1, 1:Ny+1))'
15+
node_array::AbstractMatrix, Nx, Ny) = permutedims(view(ξ, 1:Nx, 1:Ny), (2, 1))
16+
vertex_array::AbstractMatrix, Nx, Ny) = permutedims(view(ξ, 1:Nx+1, 1:Ny+1), (2, 1))
17+
18+
x_node_array(x::AbstractVector, Nx, Ny) = permutedims(repeat(view(x, 1:Nx), 1, Ny), (2, 1))
19+
x_node_array(x::AbstractMatrix, Nx, Ny) = node_array(x, Nx, Ny)
20+
21+
y_node_array(y::AbstractVector, Nx, Ny) = repeat(view(y, 1:Ny), 1, Nx)
22+
y_node_array(y::AbstractMatrix, Nx, Ny) = node_array(y, Nx, Ny)
23+
24+
x_vertex_array(x::AbstractVector, Nx, Ny) = permutedims(repeat(view(x, 1:Nx+1), 1, Ny+1), (2, 1))
25+
x_vertex_array(x::AbstractMatrix, Nx, Ny) = vertex_array(x, Nx, Ny)
2726

28-
y_node_array(x::AbstractMatrix, Nx, Ny) = x_node_array(x, Nx, Ny)
29-
y_vertex_array(x::AbstractMatrix, Nx, Ny) = x_vertex_array(x, Nx, Ny)
27+
y_vertex_array(y::AbstractVector, Nx, Ny) = repeat(view(y, 1:Ny+1), 1, Nx+1)
28+
y_vertex_array(y::AbstractMatrix, Nx, Ny) = vertex_array(y, Nx, Ny)
3029

3130
function extract_xesmf_coordinates_structure(dst_field::AbstractField, src_field::AbstractField)
3231

test/runtests.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,13 @@ CUDA.allowscalar() do
252252
end
253253
end
254254

255+
# Tests for XESMF extension
256+
if group == :xesmf || group == :all
257+
@testset "XESMF extension tests" begin
258+
include("test_xesmf.jl")
259+
end
260+
end
261+
255262
if group == :sharding || group == :all
256263
@testset "Sharding Reactant extension tests" begin
257264
# Broken for the moment (trying to fix them in https://github.com/CliMA/Oceananigans.jl/pull/4293)
@@ -281,13 +288,6 @@ CUDA.allowscalar() do
281288
end
282289
end
283290

284-
# Tests for XESMF extension
285-
if group == :xesmf || group == :all
286-
@testset "XESMF extension tests" begin
287-
include("test_xesmf.jl")
288-
end
289-
end
290-
291291
if group == :convergence
292292
include("test_convergence.jl")
293293
end

test/test_xesmf.jl

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,53 @@ using XESMF
44
using SparseArrays
55
using LinearAlgebra
66

7-
z = (-1, 0)
8-
southernmost_latitude = -80
9-
radius = Oceananigans.Grids.R_Earth
7+
for arch in archs
8+
@testset "XESMF extension [$(typeof(arch))]" begin
9+
@info "Testing XESMF regridding [$(typeof(arch))]..."
1010

11-
llg_coarse = LatitudeLongitudeGrid(; size=(176, 88, 1),
12-
longitude=(0, 360),
13-
latitude=(southernmost_latitude, 90),
14-
z, radius)
11+
z = (-1, 0)
12+
southernmost_latitude = -80
13+
radius = Oceananigans.Grids.R_Earth
1514

16-
llg_fine = LatitudeLongitudeGrid(; size=(360, 180, 1),
17-
longitude=(0, 360),
18-
latitude=(southernmost_latitude, 90),
19-
z, radius)
15+
llg_coarse = LatitudeLongitudeGrid(arch; size=(176, 88, 1),
16+
longitude=(0, 360),
17+
latitude=(southernmost_latitude, 90),
18+
z, radius)
2019

21-
tg = TripolarGrid(; size=(360, 170, 1), z, southernmost_latitude, radius)
20+
llg_fine = LatitudeLongitudeGrid(arch; size=(360, 170, 1),
21+
longitude=(0, 360),
22+
latitude=(southernmost_latitude, 90),
23+
z, radius)
2224

23-
@testset "XESMF extension" begin
25+
tg = TripolarGrid(arch; size=(360, 170, 1), z, southernmost_latitude, radius)
2426

25-
for (src_grid, dst_grid) in ((llg_coarse, llg_fine),
26-
(llg_fine, llg_coarse),
27-
(tg, llg_fine))
2827

29-
@info " Regridding from $(nameof(typeof(src_grid))) to $(nameof(typeof(dst_grid)))"
28+
for (src_grid, dst_grid) in ((llg_coarse, llg_fine),
29+
(llg_fine, llg_coarse),
30+
(tg, llg_fine))
3031

31-
src_field = CenterField(src_grid)
32-
dst_field = CenterField(dst_grid)
32+
@info " Regridding from $(nameof(typeof(src_grid))) to $(nameof(typeof(dst_grid)))"
3333

34-
λ₀, φ₀ = 150, 30. # degrees
35-
width = 12 # degrees
36-
set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 +- φ₀)^2) / 2width^2))
34+
src_field = CenterField(src_grid)
35+
dst_field = CenterField(dst_grid)
3736

38-
R = XESMF.Regridder(dst_field, src_field)
39-
@test R.weights isa SparseMatrixCSC
37+
λ₀, φ₀ = 150, 30. # degrees
38+
width = 12 # degrees
39+
set!(src_field, (λ, φ, z) -> exp(-((λ - λ₀)^2 +- φ₀)^2) / 2width^2))
4040

41-
regrid!(dst_field, R, src_field)
41+
regridder = XESMF.Regridder(dst_field, src_field)
4242

43-
# ∫ dst_field dA = ∫ src_field dA
44-
@test isapprox(first(Field(Integral(dst_field, dims=(1, 2)))),
45-
first(Field(Integral(src_field, dims=(1, 2)))), rtol=1e-4)
43+
if arch isa CPU
44+
@test regridder.weights isa SparseMatrixCSC
45+
elseif arch isa GPU{CUDABackend}
46+
@test regridder.weights isa CUDA.CUSPARSE.CuSparseMatrixCSC
47+
end
48+
49+
regrid!(dst_field, regridder, src_field)
50+
51+
# ∫ dst_field dA ≈ ∫ src_field dA
52+
@test @allowscalar isapprox(first(Field(Integral(dst_field, dims=(1, 2)))),
53+
first(Field(Integral(src_field, dims=(1, 2)))), rtol=1e-4)
54+
end
4655
end
4756
end

0 commit comments

Comments
 (0)