Skip to content

Commit 5b876f1

Browse files
committed
Try simpler bilinear remapping (for use in netcdf diagnostics)
1 parent 3c150fa commit 5b876f1

File tree

8 files changed

+906
-63
lines changed

8 files changed

+906
-63
lines changed

docs/src/remapping.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,12 @@ given `field`. To obtain such coordinates, you can call the
4545
functions. These functions return an `Array` with the coordinates over which
4646
interpolation will occur. These arrays are of type `Geometry.Point`s.
4747

48-
By default, vertical interpolation is switched off and the `field` is evaluated
49-
directly on the levels.
48+
By default, vertical interpolation is off (field evaluated on levels). Horizontal
49+
interpolation: `:spectral` (default; uses spectral element quadrature weights) or `:bilinear`:
50+
51+
```julia
52+
interpolated_array = Remapping.interpolate(field; horizontal_method = :bilinear)
53+
```
5054

5155
`ClimaCore.Remapping.interpolate` allocates new output arrays. As such, it is
5256
not suitable for performance-critical applications.
@@ -76,9 +80,14 @@ hcoords = [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts]
7680
zcoords = [Geometry.ZPoint(z) for z in zpts]
7781

7882
interpolated_array = interpolate(field, hcoords, zcoords)
83+
# Bilinear: interpolate(field, hcoords, zcoords; horizontal_method = :bilinear)
7984
```
8085
The output is defined on the Cartesian product of `hcoords` with `zcoords`.
8186

87+
#### Diagnostics and NetCDF writers
88+
89+
Pass `horizontal_method` through when remapping for output (e.g. `Remapping.interpolate(..., horizontal_method = :bilinear)` or `Remapper(..., horizontal_method = :bilinear)`).
90+
8291
If the default target coordinates are being used, it is possible to broadcast
8392
`ClimaCore.Geometry.components` to extract them as a vector of tuples (and then
8493
broadcast `getindex` to extract the respective coordinates as vectors).

examples/remap_visualization.jl

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Visualize bilinear and spectral interpolation remapping: slotted cylinder (Zalesak)
2+
#
3+
# Compares bilinear vs spectral horizontal remap on the slotted cylinder test case
4+
# (disk with rectangular slot; f ∈ {0, 1}). Parameters highlight spectral (Lagrange)
5+
# overshoot/undershoot vs bilinear.
6+
#
7+
# Run from repo root with:
8+
# julia --project=examples examples/remap_visualization.jl
9+
# or with the main project:
10+
# julia --project=. examples/remap_visualization.jl
11+
12+
using ClimaComms
13+
using ClimaCore:
14+
Geometry,
15+
Domains,
16+
Meshes,
17+
Topologies,
18+
Spaces,
19+
Fields,
20+
Remapping,
21+
Quadratures
22+
using CairoMakie
23+
24+
device = ClimaComms.CPUSingleThreaded()
25+
26+
nelements_horz = 6 # horizontal elements per dimension
27+
Nq = 4 # GLL points per dimension
28+
n_interp = 24 # target grid resolution for interpolation
29+
30+
# Slotted cylinder (Zalesak): disk with rectangular slot; f ∈ {0, 1}
31+
slot_radius = 0.15
32+
slot_cx, slot_cy = 0.5, 0.5
33+
slot_half_width = 0.025
34+
slot_y_hi = slot_cy + slot_radius
35+
36+
# --- Domain: square [0, 1] × [0, 1] (periodic) ---
37+
horzdomain = Domains.RectangleDomain(
38+
Geometry.XPoint(0.0) .. Geometry.XPoint(1.0),
39+
Geometry.YPoint(0.0) .. Geometry.YPoint(1.0),
40+
x1periodic = true,
41+
x2periodic = true,
42+
)
43+
44+
# --- Vertical: single layer ---
45+
vertdomain = Domains.IntervalDomain(
46+
Geometry.ZPoint(0.0),
47+
Geometry.ZPoint(1.0);
48+
boundary_names = (:bottom, :top),
49+
)
50+
vertmesh = Meshes.IntervalMesh(vertdomain, nelems = 1)
51+
verttopo = Topologies.IntervalTopology(ClimaComms.SingletonCommsContext(device), vertmesh)
52+
vert_center_space = Spaces.CenterFiniteDifferenceSpace(verttopo)
53+
54+
# --- Horizontal: spectral elements ---
55+
quad = Quadratures.GLL{Nq}()
56+
horzmesh = Meshes.RectilinearMesh(horzdomain, nelements_horz, nelements_horz)
57+
horztopology = Topologies.Topology2D(ClimaComms.SingletonCommsContext(device), horzmesh)
58+
horzspace = Spaces.SpectralElementSpace2D(horztopology, quad)
59+
hv_center_space = Spaces.ExtrudedFiniteDifferenceSpace(horzspace, vert_center_space)
60+
61+
# --- Slotted cylinder field ---
62+
coords = Fields.coordinate_field(hv_center_space)
63+
function slotted_cylinder(x, y)
64+
in_disk = (x - slot_cx)^2 + (y - slot_cy)^2 <= slot_radius^2
65+
in_slot = (abs(x - slot_cx) <= slot_half_width) && (y >= slot_cy) && (y <= slot_y_hi)
66+
return (in_disk && !in_slot) ? 1.0 : 0.0
67+
end
68+
field = @. slotted_cylinder(coords.x, coords.y)
69+
Spaces.weighted_dss!(field)
70+
71+
# --- Target grid: uniform n_interp×n_interp, single vertical level ---
72+
xpts = range(Geometry.XPoint(0.0), Geometry.XPoint(1.0), length = n_interp)
73+
ypts = range(Geometry.YPoint(0.0), Geometry.YPoint(1.0), length = n_interp)
74+
zpts = range(Geometry.ZPoint(0.5), Geometry.ZPoint(0.5), length = 1)
75+
76+
# --- Interpolate: bilinear and spectral ---
77+
interp_bilinear =
78+
Remapping.interpolate_array(field, xpts, ypts, zpts; horizontal_method = :bilinear)
79+
interp_spectral =
80+
Remapping.interpolate_array(field, xpts, ypts, zpts; horizontal_method = :spectral)
81+
interp_bilinear_2d = interp_bilinear[:, :, 1]
82+
interp_spectral_2d = interp_spectral[:, :, 1]
83+
err_bilinear_spectral = interp_bilinear_2d .- interp_spectral_2d
84+
85+
# --- Non-negativity stats (source f ∈ {0, 1}) ---
86+
min_bilinear, max_bilinear = extrema(interp_bilinear_2d)
87+
min_spectral, max_spectral = extrema(interp_spectral_2d)
88+
n_neg = count(<(0), interp_spectral_2d)
89+
n_gt1 = count(>(1), interp_spectral_2d)
90+
@info "Slotted cylinder: non-negativity (source f ∈ {0,1})" bilinear_min = min_bilinear bilinear_max =
91+
max_bilinear spectral_min = min_spectral spectral_max = max_spectral spectral_below_0 =
92+
n_neg spectral_above_1 = n_gt1
93+
94+
# --- Raw spectral element grid (GLL nodes, v=1) ---
95+
x_se = Float64[]
96+
y_se = Float64[]
97+
vals_se = Float64[]
98+
Fields.byslab(hv_center_space) do slabidx
99+
slabidx.v == 1 || return
100+
x_data = parent(Fields.slab(coords.x, slabidx))
101+
y_data = parent(Fields.slab(coords.y, slabidx))
102+
f_data = parent(Fields.slab(field, slabidx))
103+
for j in 1:Nq, i in 1:Nq
104+
push!(x_se, x_data[i, j, 1])
105+
push!(y_se, y_data[i, j, 1])
106+
push!(vals_se, f_data[i, j, 1])
107+
end
108+
end
109+
110+
x_plot = [p.x for p in xpts]
111+
y_plot = [p.y for p in ypts]
112+
boundary_pos = (0:nelements_horz) ./ nelements_horz
113+
114+
# --- Figure: bilinear | spectral | error; row 2 = raw GLL nodes ---
115+
fig = Figure(size = (1200, 800))
116+
117+
ax1 = Axis(fig[1, 1], title = "Bilinear ($n_interp×$n_interp)", xlabel = "x", ylabel = "y")
118+
hm1 = heatmap!(
119+
ax1,
120+
x_plot,
121+
y_plot,
122+
interp_bilinear_2d';
123+
colorrange = (0, 1),
124+
colormap = :viridis,
125+
lowclip = :orange,
126+
highclip = :red,
127+
)
128+
Colorbar(fig[1, 2], hm1; label = "value")
129+
130+
ax2 = Axis(fig[1, 3], title = "Spectral ($n_interp×$n_interp)", xlabel = "x", ylabel = "y")
131+
hm2 = heatmap!(
132+
ax2,
133+
x_plot,
134+
y_plot,
135+
interp_spectral_2d';
136+
colorrange = (0, 1),
137+
colormap = :viridis,
138+
lowclip = :orange,
139+
highclip = :red,
140+
)
141+
Colorbar(fig[1, 4], hm2; label = "value")
142+
143+
ax3 = Axis(
144+
fig[1, 5],
145+
title = "Error (bilinear − spectral)",
146+
xlabel = "x",
147+
ylabel = "y",
148+
)
149+
erange = extrema(err_bilinear_spectral)
150+
hm3 = heatmap!(
151+
ax3,
152+
x_plot,
153+
y_plot,
154+
err_bilinear_spectral';
155+
colorrange = erange,
156+
colormap = :RdBu,
157+
)
158+
Colorbar(fig[1, 6], hm3; label = "error")
159+
160+
ax_se = Axis(
161+
fig[2, 1],
162+
title = "Raw spectral element grid (GLL nodes)",
163+
xlabel = "x",
164+
ylabel = "y",
165+
)
166+
sc_se = scatter!(
167+
ax_se,
168+
y_se,
169+
x_se;
170+
color = vals_se,
171+
colorrange = (0, 1),
172+
colormap = :viridis,
173+
lowclip = :orange,
174+
highclip = :red,
175+
markersize = 8,
176+
)
177+
vlines!(ax_se, boundary_pos; color = :pink, linewidth = 2)
178+
hlines!(ax_se, boundary_pos; color = :pink, linewidth = 2)
179+
limits!(ax_se, 0, 1, 0, 1)
180+
Colorbar(fig[2, 2], sc_se; label = "value")
181+
182+
outdir = joinpath(@__DIR__, "output")
183+
mkpath(outdir)
184+
outpath = joinpath(outdir, "remap_slotted_cylinder_$(n_interp)x$(n_interp).png")
185+
save(outpath, fig)
186+
@info "Saved to $outpath"

ext/cuda/remapping_distributed.jl

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,161 @@ import CUDA
44
using CUDA: @cuda
55
import ClimaCore.Remapping: _set_interpolated_values_device!
66

7+
# Bilinear in (s,t) ∈ [0,1]² (local to 2-point cell; reference element is [-1,1]²):
8+
# (1-s)(1-t)*c11 + s*(1-t)*c21 + (1-s)*t*c12 + s*t*c22
9+
@inline _bilinear(c11, c21, c22, c12, s, t) =
10+
(1 - s) * (1 - t) * c11 + s * (1 - t) * c21 + (1 - s) * t * c12 + s * t * c22
11+
12+
# Bilinear path on GPU: pure GPU kernels (no scalar indexing).
13+
# 2-point cell (i..i+1, j..j+1) containing (ξ1, ξ2); interpolate between nodes within the element.
14+
function ClimaCore.Remapping._set_interpolated_values_bilinear!(
15+
out::CUDA.CuArray,
16+
fields::AbstractArray{<:Fields.Field},
17+
scratch_corners,
18+
local_horiz_indices,
19+
vert_interpolation_weights::AbstractArray,
20+
vert_bounding_indices::AbstractArray,
21+
local_bilinear_s,
22+
local_bilinear_t,
23+
local_bilinear_i,
24+
local_bilinear_j,
25+
)
26+
field_values = tuple(map(f -> Fields.field_values(f), fields)...)
27+
num_horiz = length(local_horiz_indices)
28+
num_vert = length(vert_bounding_indices)
29+
num_fields = length(field_values)
30+
nitems = length(out)
31+
args = (
32+
out,
33+
local_horiz_indices,
34+
local_bilinear_s,
35+
local_bilinear_t,
36+
local_bilinear_i,
37+
local_bilinear_j,
38+
vert_interpolation_weights,
39+
vert_bounding_indices,
40+
field_values,
41+
)
42+
threads = threads_via_occupancy(set_interpolated_values_bilinear_3d_kernel!, args)
43+
p = linear_partition(nitems, threads)
44+
auto_launch!(
45+
set_interpolated_values_bilinear_3d_kernel!,
46+
args;
47+
threads_s = (p.threads,),
48+
blocks_s = (p.blocks,),
49+
)
50+
end
51+
52+
function set_interpolated_values_bilinear_3d_kernel!(
53+
out,
54+
local_horiz_indices,
55+
local_bilinear_s,
56+
local_bilinear_t,
57+
local_bilinear_i,
58+
local_bilinear_j,
59+
vert_interpolation_weights,
60+
vert_bounding_indices,
61+
field_values,
62+
)
63+
num_horiz = length(local_horiz_indices)
64+
num_vert = length(vert_bounding_indices)
65+
num_fields = length(field_values)
66+
inds = (num_horiz, num_vert, num_fields)
67+
i_thread =
68+
(CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
69+
1 i_thread prod(inds) || return nothing
70+
(i_out, j_v, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
71+
@inbounds begin
72+
CI = CartesianIndex
73+
h = local_horiz_indices[i_out]
74+
v_lo, v_hi = vert_bounding_indices[j_v]
75+
A, B = vert_interpolation_weights[j_v]
76+
s = local_bilinear_s[i_out]
77+
t = local_bilinear_t[i_out]
78+
ii = local_bilinear_i[i_out]
79+
jj = local_bilinear_j[i_out]
80+
fvals = field_values[k]
81+
# Four nodes of 2-point cell: (ii,jj), (ii+1,jj), (ii+1,jj+1), (ii,jj+1)
82+
c11 = A * fvals[CI(ii, jj, 1, v_lo, h)] + B * fvals[CI(ii, jj, 1, v_hi, h)]
83+
c21 = A * fvals[CI(ii + 1, jj, 1, v_lo, h)] + B * fvals[CI(ii + 1, jj, 1, v_hi, h)]
84+
c22 =
85+
A * fvals[CI(ii + 1, jj + 1, 1, v_lo, h)] +
86+
B * fvals[CI(ii + 1, jj + 1, 1, v_hi, h)]
87+
c12 = A * fvals[CI(ii, jj + 1, 1, v_lo, h)] + B * fvals[CI(ii, jj + 1, 1, v_hi, h)]
88+
out[i_out, j_v, k] = _bilinear(c11, c21, c22, c12, s, t)
89+
end
90+
return nothing
91+
end
92+
93+
function ClimaCore.Remapping._set_interpolated_values_bilinear!(
94+
out::CUDA.CuArray,
95+
fields::AbstractArray{<:Fields.Field},
96+
scratch_corners,
97+
local_horiz_indices,
98+
::Nothing,
99+
::Nothing,
100+
local_bilinear_s,
101+
local_bilinear_t,
102+
local_bilinear_i,
103+
local_bilinear_j,
104+
)
105+
field_values = tuple(map(f -> Fields.field_values(f), fields)...)
106+
num_horiz = length(local_horiz_indices)
107+
num_fields = length(field_values)
108+
nitems = length(out)
109+
args = (
110+
out,
111+
local_horiz_indices,
112+
local_bilinear_s,
113+
local_bilinear_t,
114+
local_bilinear_i,
115+
local_bilinear_j,
116+
field_values,
117+
)
118+
threads = threads_via_occupancy(set_interpolated_values_bilinear_2d_kernel!, args)
119+
p = linear_partition(nitems, threads)
120+
auto_launch!(
121+
set_interpolated_values_bilinear_2d_kernel!,
122+
args;
123+
threads_s = (p.threads,),
124+
blocks_s = (p.blocks,),
125+
)
126+
end
127+
128+
function set_interpolated_values_bilinear_2d_kernel!(
129+
out,
130+
local_horiz_indices,
131+
local_bilinear_s,
132+
local_bilinear_t,
133+
local_bilinear_i,
134+
local_bilinear_j,
135+
field_values,
136+
)
137+
num_horiz = length(local_horiz_indices)
138+
num_fields = length(field_values)
139+
inds = (num_horiz, num_fields)
140+
i_thread =
141+
(CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
142+
1 i_thread prod(inds) || return nothing
143+
(i_out, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[i_thread].I
144+
@inbounds begin
145+
CI = CartesianIndex
146+
h = local_horiz_indices[i_out]
147+
s = local_bilinear_s[i_out]
148+
t = local_bilinear_t[i_out]
149+
ii = local_bilinear_i[i_out]
150+
jj = local_bilinear_j[i_out]
151+
fvals = field_values[k]
152+
# Four nodes of 2-point cell: (ii,jj), (ii+1,jj), (ii+1,jj+1), (ii,jj+1)
153+
c11 = fvals[CI(ii, jj, 1, 1, h)]
154+
c21 = fvals[CI(ii + 1, jj, 1, 1, h)]
155+
c22 = fvals[CI(ii + 1, jj + 1, 1, 1, h)]
156+
c12 = fvals[CI(ii, jj + 1, 1, 1, h)]
157+
out[i_out, k] = _bilinear(c11, c21, c22, c12, s, t)
158+
end
159+
return nothing
160+
end
161+
7162

8163
function _set_interpolated_values_device!(
9164
out::AbstractArray,

0 commit comments

Comments
 (0)