Skip to content

Commit bb81b40

Browse files
committed
Try simpler bilinear remapping (for use in netcdf diagnostics)
1 parent 3c150fa commit bb81b40

File tree

8 files changed

+904
-63
lines changed

8 files changed

+904
-63
lines changed

docs/src/remapping.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ given `field`. To obtain such coordinates, you can call the
4545
functions. These functions return an `Array` with the coordinates over which
4646
interpolation will occur. These arrays are of type `Geometry.Point`s.
4747

48-
By default, vertical interpolation is switched off and the `field` is evaluated
49-
directly on the levels.
48+
By default, vertical interpolation is off (field evaluated on levels). Horizontal
49+
interpolation: `:spectral` (default; uses spectral element quadrature weights) or `:bilinear`:
50+
51+
```julia
52+
interpolated_array = Remapping.interpolate(field; horizontal_method = :bilinear)
53+
```
5054

5155
`ClimaCore.Remapping.interpolate` allocates new output arrays. As such, it is
5256
not suitable for performance-critical applications.
@@ -76,9 +80,14 @@ hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
7680
zcoords = [Geometry.ZPoint(z) for z in zpts]
7781

7882
interpolated_array = interpolate(field, hcoords, zcoords)
83+
# Bilinear: interpolate(field, hcoords, zcoords; horizontal_method = :bilinear)
7984
```
8085
The output is defined on the Cartesian product of `hcoords` with `zcoords`.
8186

87+
#### Diagnostics and NetCDF writers
88+
89+
Pass `horizontal_method` through when remapping for output (e.g. `Remapping.interpolate(..., horizontal_method = :bilinear)` or `Remapper(..., horizontal_method = :bilinear)`).
90+
8291
If the default target coordinates are being used, it is possible to broadcast
8392
`ClimaCore.Geometry.components` to extract them as a vector of tuples (and then
8493
broadcast `getindex` to extract the respective coordinates as vectors).

examples/remap_visualization.jl

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Visualize bilinear and spectral interpolation remapping: slotted cylinder (Zalesak)
2+
#
3+
# Compares bilinear vs spectral horizontal remap on the slotted cylinder test case
4+
# (disk with rectangular slot; f ∈ {0, 1}). Parameters highlight spectral (Lagrange)
5+
# overshoot/undershoot vs bilinear.
6+
#
7+
# Run from repo root with:
8+
# julia --project=examples examples/remap_visualization.jl
9+
# or with the main project:
10+
# julia --project=. examples/remap_visualization.jl
11+
12+
using ClimaComms
13+
using ClimaCore:
14+
Geometry,
15+
Domains,
16+
Meshes,
17+
Topologies,
18+
Spaces,
19+
Fields,
20+
Remapping,
21+
Quadratures
22+
using CairoMakie
23+
24+
device = ClimaComms.CPUSingleThreaded()
25+
26+
nelements_horz = 6 # horizontal elements per dimension
27+
Nq = 4 # GLL points per dimension
28+
n_interp = 24 # target grid resolution for interpolation
29+
30+
# Slotted cylinder (Zalesak): disk with rectangular slot; f ∈ {0, 1}
31+
slot_radius = 0.15
32+
slot_cx, slot_cy = 0.5, 0.5
33+
slot_half_width = 0.025
34+
slot_y_hi = slot_cy + slot_radius
35+
36+
# --- Domain: square [0, 1] × [0, 1] (periodic) ---
37+
horzdomain = Domains.RectangleDomain(
38+
Geometry.XPoint(0.0) .. Geometry.XPoint(1.0),
39+
Geometry.YPoint(0.0) .. Geometry.YPoint(1.0),
40+
x1periodic = true,
41+
x2periodic = true,
42+
)
43+
44+
# --- Vertical: single layer ---
45+
vertdomain = Domains.IntervalDomain(
46+
Geometry.ZPoint(0.0),
47+
Geometry.ZPoint(1.0);
48+
boundary_names = (:bottom, :top),
49+
)
50+
vertmesh = Meshes.IntervalMesh(vertdomain, nelems = 1)
51+
verttopo = Topologies.IntervalTopology(ClimaComms.SingletonCommsContext(device), vertmesh)
52+
vert_center_space = Spaces.CenterFiniteDifferenceSpace(verttopo)
53+
54+
# --- Horizontal: spectral elements ---
55+
quad = Quadratures.GLL{Nq}()
56+
horzmesh = Meshes.RectilinearMesh(horzdomain, nelements_horz, nelements_horz)
57+
horztopology = Topologies.Topology2D(ClimaComms.SingletonCommsContext(device), horzmesh)
58+
horzspace = Spaces.SpectralElementSpace2D(horztopology, quad)
59+
hv_center_space = Spaces.ExtrudedFiniteDifferenceSpace(horzspace, vert_center_space)
60+
61+
# --- Slotted cylinder field ---
62+
coords = Fields.coordinate_field(hv_center_space)
63+
function slotted_cylinder(x, y)
64+
in_disk = (x - slot_cx)^2 + (y - slot_cy)^2 <= slot_radius^2
65+
in_slot = (abs(x - slot_cx) <= slot_half_width) && (y >= slot_cy) && (y <= slot_y_hi)
66+
return (in_disk && !in_slot) ? 1.0 : 0.0
67+
end
68+
field = @. slotted_cylinder(coords.x, coords.y)
69+
Spaces.weighted_dss!(field)
70+
71+
# --- Target grid: uniform n_interp×n_interp, single vertical level ---
72+
xpts = range(Geometry.XPoint(0.0), Geometry.XPoint(1.0), length = n_interp)
73+
ypts = range(Geometry.YPoint(0.0), Geometry.YPoint(1.0), length = n_interp)
74+
zpts = range(Geometry.ZPoint(0.5), Geometry.ZPoint(0.5), length = 1)
75+
76+
# --- Interpolate: bilinear and spectral ---
77+
interp_bilinear =
78+
Remapping.interpolate_array(field, xpts, ypts, zpts; horizontal_method = :bilinear)
79+
interp_spectral =
80+
Remapping.interpolate_array(field, xpts, ypts, zpts; horizontal_method = :spectral)
81+
interp_bilinear_2d = interp_bilinear[:, :, 1]
82+
interp_spectral_2d = interp_spectral[:, :, 1]
83+
err_bilinear_spectral = interp_bilinear_2d .- interp_spectral_2d
84+
85+
# --- Non-negativity stats (source f ∈ {0, 1}) ---
86+
min_bilinear, max_bilinear = extrema(interp_bilinear_2d)
87+
min_spectral, max_spectral = extrema(interp_spectral_2d)
88+
n_neg = count(<(0), interp_spectral_2d)
89+
n_gt1 = count(>(1), interp_spectral_2d)
90+
@info "Slotted cylinder: non-negativity (source f ∈ {0,1})" bilinear_min = min_bilinear bilinear_max =
91+
max_bilinear spectral_min = min_spectral spectral_max = max_spectral spectral_below_0 =
92+
n_neg spectral_above_1 = n_gt1
93+
94+
# --- Raw spectral element grid (GLL nodes, v=1) ---
95+
x_se = Float64[]
96+
y_se = Float64[]
97+
vals_se = Float64[]
98+
Fields.byslab(hv_center_space) do slabidx
99+
slabidx.v == 1 || return
100+
x_data = parent(Fields.slab(coords.x, slabidx))
101+
y_data = parent(Fields.slab(coords.y, slabidx))
102+
f_data = parent(Fields.slab(field, slabidx))
103+
for j in 1:Nq, i in 1:Nq
104+
push!(x_se, x_data[i, j, 1])
105+
push!(y_se, y_data[i, j, 1])
106+
push!(vals_se, f_data[i, j, 1])
107+
end
108+
end
109+
110+
x_plot = [p.x for p in xpts]
111+
y_plot = [p.y for p in ypts]
112+
boundary_pos = (0:nelements_horz) ./ nelements_horz
113+
114+
# --- Figure: bilinear | spectral | error; row 2 = raw GLL nodes ---
115+
fig = Figure(size = (1200, 800))
116+
117+
ax1 = Axis(fig[1, 1], title = "Bilinear ($n_interp×$n_interp)", xlabel = "x", ylabel = "y")
118+
hm1 = heatmap!(
119+
ax1,
120+
x_plot,
121+
y_plot,
122+
interp_bilinear_2d';
123+
colorrange = (0, 1),
124+
colormap = :viridis,
125+
lowclip = :orange,
126+
highclip = :red,
127+
)
128+
Colorbar(fig[1, 2], hm1; label = "value")
129+
130+
ax2 = Axis(fig[1, 3], title = "Spectral ($n_interp×$n_interp)", xlabel = "x", ylabel = "y")
131+
hm2 = heatmap!(
132+
ax2,
133+
x_plot,
134+
y_plot,
135+
interp_spectral_2d';
136+
colorrange = (0, 1),
137+
colormap = :viridis,
138+
lowclip = :orange,
139+
highclip = :red,
140+
)
141+
Colorbar(fig[1, 4], hm2; label = "value")
142+
143+
ax3 = Axis(
144+
fig[1, 5],
145+
title = "Error (bilinear − spectral)",
146+
xlabel = "x",
147+
ylabel = "y",
148+
)
149+
erange = extrema(err_bilinear_spectral)
150+
hm3 = heatmap!(
151+
ax3,
152+
x_plot,
153+
y_plot,
154+
err_bilinear_spectral';
155+
colorrange = erange,
156+
colormap = :RdBu,
157+
)
158+
Colorbar(fig[1, 6], hm3; label = "error")
159+
160+
ax_se = Axis(
161+
fig[2, 1],
162+
title = "Raw spectral element grid (GLL nodes)",
163+
xlabel = "x",
164+
ylabel = "y",
165+
)
166+
sc_se = scatter!(
167+
ax_se,
168+
y_se,
169+
x_se;
170+
color = vals_se,
171+
colorrange = (0, 1),
172+
colormap = :viridis,
173+
lowclip = :orange,
174+
highclip = :red,
175+
markersize = 8,
176+
)
177+
vlines!(ax_se, boundary_pos; color = :pink, linewidth = 2)
178+
hlines!(ax_se, boundary_pos; color = :pink, linewidth = 2)
179+
limits!(ax_se, 0, 1, 0, 1)
180+
Colorbar(fig[2, 2], sc_se; label = "value")
181+
182+
outdir = joinpath(@__DIR__, "output")
183+
mkpath(outdir)
184+
outpath = joinpath(outdir, "remap_slotted_cylinder_$(n_interp)x$(n_interp).png")
185+
save(outpath, fig)
186+
@info "Saved to $outpath"

ext/cuda/remapping_distributed.jl

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,160 @@ import CUDA
44
using CUDA: @cuda
55
import ClimaCore.Remapping: _set_interpolated_values_device!
66

7+
# Bilinear in [0,1]²: (1-s)(1-t)*c11 + s*(1-t)*c21 + (1-s)*t*c12 + s*t*c22
8+
@inline _bilinear(c11, c21, c22, c12, s, t) =
9+
(1 - s) * (1 - t) * c11 + s * (1 - t) * c21 + (1 - s) * t * c12 + s * t * c22
10+
11+
# Bilinear path on GPU: pure GPU kernels (no scalar indexing).
12+
# 2-point cell (i..i+1, j..j+1) containing (ξ1, ξ2); interpolate between nodes within the element.
13+
function ClimaCore.Remapping._set_interpolated_values_bilinear!(
14+
out::CUDA.CuArray,
15+
fields::AbstractArray{<:Fields.Field},
16+
scratch_corners,
17+
local_horiz_indices,
18+
vert_interpolation_weights::AbstractArray,
19+
vert_bounding_indices::AbstractArray,
20+
local_bilinear_s,
21+
local_bilinear_t,
22+
local_bilinear_i,
23+
local_bilinear_j,
24+
)
25+
field_values = tuple(map(f -> Fields.field_values(f), fields)...)
26+
num_horiz = length(local_horiz_indices)
27+
num_vert = length(vert_bounding_indices)
28+
num_fields = length(field_values)
29+
nitems = length(out)
30+
args = (
31+
out,
32+
local_horiz_indices,
33+
local_bilinear_s,
34+
local_bilinear_t,
35+
local_bilinear_i,
36+
local_bilinear_j,
37+
vert_interpolation_weights,
38+
vert_bounding_indices,
39+
field_values,
40+
)
41+
threads = threads_via_occupancy(set_interpolated_values_bilinear_3d_kernel!, args)
42+
p = linear_partition(nitems, threads)
43+
auto_launch!(
44+
set_interpolated_values_bilinear_3d_kernel!,
45+
args;
46+
threads_s = (p.threads,),
47+
blocks_s = (p.blocks,),
48+
)
49+
end
50+
51+
function set_interpolated_values_bilinear_3d_kernel!(
52+
out,
53+
local_horiz_indices,
54+
local_bilinear_s,
55+
local_bilinear_t,
56+
local_bilinear_i,
57+
local_bilinear_j,
58+
vert_interpolation_weights,
59+
vert_bounding_indices,
60+
field_values,
61+
)
62+
num_horiz = length(local_horiz_indices)
63+
num_vert = length(vert_bounding_indices)
64+
num_fields = length(field_values)
65+
inds = (num_horiz, num_vert, num_fields)
66+
i_thread =
67+
(CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
68+
1 i_thread prod(inds) || return nothing
69+
(i_out, j_v, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
70+
@inbounds begin
71+
CI = CartesianIndex
72+
h = local_horiz_indices[i_out]
73+
v_lo, v_hi = vert_bounding_indices[j_v]
74+
A, B = vert_interpolation_weights[j_v]
75+
s = local_bilinear_s[i_out]
76+
t = local_bilinear_t[i_out]
77+
ii = local_bilinear_i[i_out]
78+
jj = local_bilinear_j[i_out]
79+
fvals = field_values[k]
80+
# Four nodes of 2-point cell: (ii,jj), (ii+1,jj), (ii+1,jj+1), (ii,jj+1)
81+
c11 = A * fvals[CI(ii, jj, 1, v_lo, h)] + B * fvals[CI(ii, jj, 1, v_hi, h)]
82+
c21 = A * fvals[CI(ii + 1, jj, 1, v_lo, h)] + B * fvals[CI(ii + 1, jj, 1, v_hi, h)]
83+
c22 =
84+
A * fvals[CI(ii + 1, jj + 1, 1, v_lo, h)] +
85+
B * fvals[CI(ii + 1, jj + 1, 1, v_hi, h)]
86+
c12 = A * fvals[CI(ii, jj + 1, 1, v_lo, h)] + B * fvals[CI(ii, jj + 1, 1, v_hi, h)]
87+
out[i_out, j_v, k] = _bilinear(c11, c21, c22, c12, s, t)
88+
end
89+
return nothing
90+
end
91+
92+
function ClimaCore.Remapping._set_interpolated_values_bilinear!(
93+
out::CUDA.CuArray,
94+
fields::AbstractArray{<:Fields.Field},
95+
scratch_corners,
96+
local_horiz_indices,
97+
::Nothing,
98+
::Nothing,
99+
local_bilinear_s,
100+
local_bilinear_t,
101+
local_bilinear_i,
102+
local_bilinear_j,
103+
)
104+
field_values = tuple(map(f -> Fields.field_values(f), fields)...)
105+
num_horiz = length(local_horiz_indices)
106+
num_fields = length(field_values)
107+
nitems = length(out)
108+
args = (
109+
out,
110+
local_horiz_indices,
111+
local_bilinear_s,
112+
local_bilinear_t,
113+
local_bilinear_i,
114+
local_bilinear_j,
115+
field_values,
116+
)
117+
threads = threads_via_occupancy(set_interpolated_values_bilinear_2d_kernel!, args)
118+
p = linear_partition(nitems, threads)
119+
auto_launch!(
120+
set_interpolated_values_bilinear_2d_kernel!,
121+
args;
122+
threads_s = (p.threads,),
123+
blocks_s = (p.blocks,),
124+
)
125+
end
126+
127+
function set_interpolated_values_bilinear_2d_kernel!(
128+
out,
129+
local_horiz_indices,
130+
local_bilinear_s,
131+
local_bilinear_t,
132+
local_bilinear_i,
133+
local_bilinear_j,
134+
field_values,
135+
)
136+
num_horiz = length(local_horiz_indices)
137+
num_fields = length(field_values)
138+
inds = (num_horiz, num_fields)
139+
i_thread =
140+
(CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
141+
1 i_thread prod(inds) || return nothing
142+
(i_out, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
143+
@inbounds begin
144+
CI = CartesianIndex
145+
h = local_horiz_indices[i_out]
146+
s = local_bilinear_s[i_out]
147+
t = local_bilinear_t[i_out]
148+
ii = local_bilinear_i[i_out]
149+
jj = local_bilinear_j[i_out]
150+
fvals = field_values[k]
151+
# Four nodes of 2-point cell: (ii,jj), (ii+1,jj), (ii+1,jj+1), (ii,jj+1)
152+
c11 = fvals[CI(ii, jj, 1, 1, h)]
153+
c21 = fvals[CI(ii + 1, jj, 1, 1, h)]
154+
c22 = fvals[CI(ii + 1, jj + 1, 1, 1, h)]
155+
c12 = fvals[CI(ii, jj + 1, 1, 1, h)]
156+
out[i_out, k] = _bilinear(c11, c21, c22, c12, s, t)
157+
end
158+
return nothing
159+
end
160+
7161

8162
function _set_interpolated_values_device!(
9163
out::AbstractArray,

0 commit comments

Comments
 (0)