Skip to content

Commit a2b729e

Browse files
committed
Add some CUDA compatibility
1 parent 052a0de commit a2b729e

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

ext/InterpolationsRegridderExt.jl

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,20 @@ function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
9494
# horizontal separately
9595
@warn "Ignoring boundary conditions, implementing Periodic, Flat, Flat"
9696

97-
return map(regridder.coordinates) do coord
98-
ClimaComms.allowscalar(
99-
interpolation_3d_z,
100-
device,
101-
data,
102-
dimensions_FT...,
103-
totuple(coord)...,
104-
)
97+
adapted_data = Adapt.adapt(ClimaComms.array_type(regridder.target_space), data)
98+
xs, ys, zs = dimensions_FT
99+
adapted_xs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), xs)
100+
adapted_ys = Adapt.adapt(ClimaComms.array_type(regridder.target_space), ys)
101+
adapted_zs = Adapt.adapt(ClimaComms.array_type(regridder.target_space), zs)
102+
103+
return ClimaComms.allowscalar(ClimaComms.device(regridder.target_space)) do
104+
map(regridder.coordinates) do coord
105+
interpolation_3d_z(
106+
adapted_data,
107+
adapted_xs, adapted_ys, adapted_zs,
108+
totuple(coord)...,
109+
)
110+
end
105111
end
106112
else
107113
# Make a linear spline
@@ -157,7 +163,8 @@ function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
157163
# end
158164

159165
# Find nearest neighbors
160-
target_x = mod(target_x, maximum(xs) - minimum(xs))
166+
x_period = xs[end] - xs[begin]
167+
target_x = mod(target_x, x_period)
161168

162169
x_index = searchsortedfirst(xs, target_x)
163170
y_index = searchsortedfirst(ys, target_y)
@@ -174,15 +181,16 @@ function interpolation_3d_z(data, xs, ys, zs, target_x, target_y, target_z)
174181
end
175182

176183
# Interpolate in z-direction
177-
z00 = zs[x0_index, y0_index, :]
178-
z01 = zs[x0_index, y1_index, :]
179-
z10 = zs[x1_index, y0_index, :]
180-
z11 = zs[x1_index, y1_index, :]
181-
182-
f00 = linear_interp_z(data[x0_index, y0_index, :], z00, target_z)
183-
f01 = linear_interp_z(data[x0_index, y1_index, :], z01, target_z)
184-
f10 = linear_interp_z(data[x1_index, y0_index, :], z10, target_z)
185-
f11 = linear_interp_z(data[x1_index, y1_index, :], z11, target_z)
184+
185+
z00 = @view zs[x0_index, y0_index, :]
186+
z01 = @view zs[x0_index, y1_index, :]
187+
z10 = @view zs[x1_index, y0_index, :]
188+
z11 = @view zs[x1_index, y1_index, :]
189+
190+
f00 = linear_interp_z(view(data,x0_index, y0_index, :), z00, target_z)
191+
f01 = linear_interp_z(view(data,x0_index, y1_index, :), z01, target_z)
192+
f10 = linear_interp_z(view(data,x1_index, y0_index, :), z10, target_z)
193+
f11 = linear_interp_z(view(data,x1_index, y1_index, :), z11, target_z)
186194

187195
# Bilinear interpolation in x-y plane
188196
val = bilinear_interp(

0 commit comments

Comments
 (0)