Skip to content

Commit ffca901

Browse files
committed
Preallocate GPU interpolant
The interpolants we in `Interpolations.jl` are described by two arrays: the knots and the coeffs. When `Adapt` is called on these interpolants, CuArrays are allocated on the GPU. For large data, this is inefficient. In this commit, I add a system to avoid these allocations. This is accomplished by add a dictionary to `InterpolationsRegridder`. This dictionary has keys that identify the size of the knots and coefficients and values the adapted splines. When `regrid` is called, we check if we have already allocated some suitable space in this dictionary, if not, we create a new spline, if we do, we write in place. This removes GPU allocations in the hot path (ie, the regridder is used in a time evolution with always the same data and dimensions), while also keeping the flexibility of reusing the same regridder with any input data.
1 parent 5ebbd64 commit ffca901

File tree

1 file changed

+59
-10
lines changed

1 file changed

+59
-10
lines changed

ext/InterpolationsRegridderExt.jl

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ struct InterpolationsRegridder{
1212
SPACE <: ClimaCore.Spaces.AbstractSpace,
1313
FIELD <: ClimaCore.Fields.Field,
1414
BC,
15+
GITP,
1516
} <: Regridders.AbstractRegridder
1617

1718
"""ClimaCore.Space where the output Field will be defined"""
@@ -22,6 +23,14 @@ struct InterpolationsRegridder{
2223

2324
"""Tuple of extrapolation conditions as accepted by Interpolations.jl"""
2425
extrapolation_bc::BC
26+
27+
# This is needed because Adapt moves from CPU to GPU and allocates new memory.
28+
"""Dictionary of preallocated areas of memory where to store the GPU interpolant (if
29+
needed). Every time new data/dimensions are used in regrid, a new entry in the
30+
dictionary is created. The keys of the dictionary a tuple of tuple
31+
`(size(dimensions), size(data))`, with `dimensions` and `data` defined in `regrid`.
32+
"""
33+
_gpuitps::GITP
2534
end
2635

2736
# Note, we swap Lat and Long! This is because according to the CF conventions longitude
@@ -75,13 +84,38 @@ function Regridders.InterpolationsRegridder(
7584
"Number of boundary conditions does not match the number of dimensions",
7685
)
7786

87+
# Let's figure out the type of _gpuitps by creating a simple spline
88+
FT = ClimaCore.Spaces.undertype(target_space)
89+
dimensions = ntuple(_ -> [zero(FT), one(FT)], num_dimensions)
90+
data = zeros(FT, ntuple(_ -> 2, num_dimensions))
91+
itp = _create_linear_spline(FT, data, dimensions, extrapolation_bc)
92+
fake_gpuitp = Adapt.adapt(ClimaComms.array_type(target_space), itp)
93+
gpuitps = Dict((size.(dimensions), size(data)) => fake_gpuitp)
94+
7895
return InterpolationsRegridder(
7996
target_space,
8097
coordinates,
81-
extrapolation_bc
98+
extrapolation_bc,
99+
gpuitps,
100+
)
101+
end
102+
103+
"""
104+
_create_linear_spline(regridder::InterpolationsRegridder, data, dimensions)
105+
106+
Create a linear spline for the given data on the given dimension (on the CPU).
107+
"""
108+
function _create_linear_spline(FT, data, dimensions, extrapolation_bc)
109+
dimensions_FT = map(d -> FT.(d), dimensions)
110+
111+
# Make a linear spline
112+
return Intp.extrapolate(
113+
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
114+
extrapolation_bc,
82115
)
83116
end
84117

118+
85119
"""
86120
regrid(regridder::InterpolationsRegridder, data, dimensions)::Field
87121
@@ -91,16 +125,31 @@ This function is allocating.
91125
"""
92126
function Regridders.regrid(regridder::InterpolationsRegridder, data, dimensions)
93127
FT = ClimaCore.Spaces.undertype(regridder.target_space)
94-
dimensions_FT = map(d -> FT.(d), dimensions)
95-
96-
# Make a linear spline
97-
itp = Intp.extrapolate(
98-
Intp.interpolate(dimensions_FT, FT.(data), Intp.Gridded(Intp.Linear())),
99-
regridder.extrapolation_bc,
100-
)
128+
itp =
129+
_create_linear_spline(FT, data, dimensions, regridder.extrapolation_bc)
130+
131+
key = (size.(dimensions), size(data))
132+
133+
if haskey(regridder._gpuitps, key)
134+
for (k, k_new) in zip(
135+
regridder._gpuitps[key].itp.knots,
136+
Adapt.adapt(
137+
ClimaComms.array_type(regridder.target_space),
138+
itp.itp.knots,
139+
),
140+
)
141+
k .= k_new
142+
end
143+
regridder._gpuitps[key].itp.coefs .= Adapt.adapt(
144+
ClimaComms.array_type(regridder.target_space),
145+
itp.itp.coefs,
146+
)
147+
else
148+
regridder._gpuitps[key] =
149+
Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
150+
end
101151

102-
# Move it to GPU (if needed)
103-
gpuitp = Adapt.adapt(ClimaComms.array_type(regridder.target_space), itp)
152+
gpuitp = regridder._gpuitps[key]
104153

105154
return map(regridder.coordinates) do coord
106155
gpuitp(totuple(coord)...)

0 commit comments

Comments
 (0)