Skip to content

Commit 6620336

Browse files
authored
Merge pull request #75 from TUM-PIK-ESM/bg/gpu-fixes
Fixes for type stability/inference issues affecting GPU support
2 parents 66aa9ef + 23f883e commit 6620336

File tree

5 files changed

+30
-30
lines changed

5 files changed

+30
-30
lines changed

src/abstract_variables.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,11 +393,10 @@ Helper method that selects only closure (auxiliary) variables declared on `obj`.
393393
@inline closure_variables(obj) = closure_variables(variables(obj))
394394
@inline function closure_variables(vars::Tuple)
395395
progvars = prognostic_variables(vars)
396-
closure_vars = mapreduce(var -> variables(var.closure), tuplejoin, progvars, init = ())
397-
return deduplicate_vars(closure_vars)
396+
all_closure_vars = fastmap(var -> variables(var.closure), progvars)
397+
return deduplicate_vars(tuplejoin(all_closure_vars...))
398398
end
399399

400-
401400
function Base.NamedTuple(vars::Tuple{Vararg{Union{AbstractVariable, Namespace}}})
402401
keys = map(varname, vars)
403402
return NamedTuple{keys}(vars)

src/processes/soil/hydrology/soil_hydraulic_properties.jl

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
Base type for unsaturated hydraulic conductivity parameterizations.
77
"""
8-
abstract type AbstractUnsatK end
8+
abstract type AbstractUnsatK{NF} end
99

1010
"""
1111
get_swrc(::AbstractUnsatK)
@@ -63,7 +63,7 @@ measurements of hydraulic properites are available.
6363
Properties:
6464
$TYPEDFIELDS
6565
"""
66-
@kwdef struct ConstantSoilHydraulics{NF, RC, UnsatK} <: AbstractSoilHydraulics{NF, RC, UnsatK}
66+
@kwdef struct ConstantSoilHydraulics{NF, RC, UnsatK <: AbstractUnsatK{NF}} <: AbstractSoilHydraulics{NF, RC, UnsatK}
6767
"Soil water retention curve"
6868
swrc::RC
6969

@@ -78,12 +78,6 @@ $TYPEDFIELDS
7878

7979
"Prescribed wilting point [-]"
8080
wilting_point::NF = 0.05
81-
82-
# TODO: Remove once FreezeCurves.jl allows for generic type bounds
83-
function ConstantSoilHydraulics(swrc::SWRC, unsat_hydraulic_cond::AbstractUnsatK, args::NF...) where {NF}
84-
adapted_swrc = adapt(NumberFormatAdaptor{NF}(), ustrip(swrc))
85-
return new{NF, typeof(adapted_swrc), typeof(unsat_hydraulic_cond)}(adapted_swrc, unsat_hydraulic_cond, args...)
86-
end
8781
end
8882

8983
function ConstantSoilHydraulics(
@@ -92,7 +86,8 @@ function ConstantSoilHydraulics(
9286
unsat_hydraulic_cond = UnsatKLinear(NF),
9387
kwargs...
9488
) where {NF}
95-
return ConstantSoilHydraulics(; swrc, unsat_hydraulic_cond, kwargs...)
89+
swrc = adapt(NumberFormatAdaptor{NF}(), ustrip(swrc))
90+
return ConstantSoilHydraulics{NF, typeof(swrc), typeof(unsat_hydraulic_cond)}(; swrc, unsat_hydraulic_cond, kwargs...)
9691
end
9792

9893
@inline saturated_hydraulic_conductivity(hydraulics::ConstantSoilHydraulics, args...) = hydraulics.sat_hydraulic_cond
@@ -110,7 +105,7 @@ and wilting point as a function of soil texture.
110105
Properties:
111106
$TYPEDFIELDS
112107
"""
113-
@kwdef struct SoilHydraulicsSURFEX{NF, RC, UnsatK} <: AbstractSoilHydraulics{NF, RC, UnsatK}
108+
@kwdef struct SoilHydraulicsSURFEX{NF, RC, UnsatK <: AbstractUnsatK{NF}} <: AbstractSoilHydraulics{NF, RC, UnsatK}
114109
"Soil water retention curve"
115110
swrc::RC
116111

@@ -128,12 +123,6 @@ $TYPEDFIELDS
128123

129124
"Exponent of field capacity adjustment due to clay content [-]"
130125
field_capacity_exp::NF = 0.35
131-
132-
# TODO: Remove once FreezeCurves.jl allows for generic type bounds
133-
function SoilHydraulicsSURFEX(swrc::SWRC, unsat_hydraulic_cond::AbstractUnsatK, args::NF...) where {NF}
134-
adapted_swrc = adapt(NumberFormatAdaptor{NF}(), ustrip(swrc))
135-
return new{NF, typeof(adapted_swrc), typeof(unsat_hydraulic_cond)}(adapted_swrc, unsat_hydraulic_cond, args...)
136-
end
137126
end
138127

139128
function SoilHydraulicsSURFEX(
@@ -142,7 +131,8 @@ function SoilHydraulicsSURFEX(
142131
unsat_hydraulic_cond = UnsatKLinear(NF),
143132
kwargs...
144133
) where {NF}
145-
return SoilHydraulicsSURFEX(; swrc, unsat_hydraulic_cond, kwargs...)
134+
swrc = adapt(NumberFormatAdaptor{NF}(), ustrip(swrc))
135+
return SoilHydraulicsSURFEX{NF, typeof(swrc), typeof(unsat_hydraulic_cond)}(; swrc, unsat_hydraulic_cond, kwargs...)
146136
end
147137

148138
# TODO: this is not quite correct, SURFEX uses a hydraulic conductivity function that decreases exponentially with depth
@@ -169,7 +159,7 @@ end
169159
Simple formulation of hydraulic conductivity as a linear function of the liquid water saturated fraction,
170160
i.e. `soil.water / (soil.water + soil.ice + soil.air)`.
171161
"""
172-
struct UnsatKLinear{NF} <: AbstractUnsatK end
162+
struct UnsatKLinear{NF} <: AbstractUnsatK{NF} end
173163

174164
UnsatKLinear(::Type{NF}) where {NF} = UnsatKLinear{NF}()
175165

@@ -194,7 +184,7 @@ volumetric fractions, assumed to include those of water, ice, and air.
194184
195185
See van Genuchten (1980) and Westermann et al. (2023).
196186
"""
197-
struct UnsatKVanGenuchten{NF} <: AbstractUnsatK
187+
struct UnsatKVanGenuchten{NF} <: AbstractUnsatK{NF}
198188
"Exponential scaling factor for ice impedance"
199189
impedance::NF
200190
end

src/processes/soil/stratigraphy/soil_volume.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ $TYPEDFIELDS
8888
organic::NF = zero(eltype(texture))
8989

9090
function MineralOrganic(texture::SoilTexture{NF}, organic::NF) where {NF}
91-
@assert 0 <= organic <= 1 "organic content must be between zero and one"
91+
@assert zero(NF) <= organic <= one(NF) "organic content must be between zero and one"
9292
return new{NF}(texture, organic)
9393
end
9494
end

src/state_variables.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,12 @@ end
211211
Retrieves all non-tendency `Field`s from `state` defined on the given `components`.
212212
"""
213213
@inline function get_fields(state, components...; except = (;))
214-
vars = mapreduce(tuplejoin, components, init = ()) do component
214+
component_vars = fastmap(components) do component
215215
allvars = variables(component)
216216
closurevars = closure_variables(component)
217217
tuplejoin(allvars, closurevars)
218218
end
219+
vars = tuplejoin(component_vars...)
219220
return ntdiff(get_fields(state, vars), except)
220221
end
221222

@@ -225,7 +226,9 @@ end
225226
Retrieves all `Field`s from `state` corresponding to prognostic variables defined on the given `components`.
226227
"""
227228
@inline function prognostic_fields(state, components...)
228-
return get_fields(state, mapreduce(prognostic_variables, tuplejoin, components))
229+
component_progvars = fastmap(prognostic_variables, components)
230+
progvars = tuplejoin(component_progvars...)
231+
return get_fields(state, progvars)
229232
end
230233

231234
"""
@@ -234,7 +237,9 @@ end
234237
Retrieves all `Field`s from `state` corresponding to tendencies defined on the given `components`.
235238
"""
236239
@inline function tendency_fields(state, components...)
237-
return get_fields(state.tendencies, mapreduce(prognostic_variables, tuplejoin, components))
240+
component_progvars = fastmap(prognostic_variables, components)
241+
progvars = tuplejoin(component_progvars...)
242+
return get_fields(state.tendencies, progvars)
238243
end
239244

240245
"""
@@ -243,7 +248,9 @@ end
243248
Retrieves all `Field`s from `state` corresponding to auxiliary variables defined on the given `components`.
244249
"""
245250
@inline function auxiliary_fields(state, components...)
246-
return get_fields(state, mapreduce(auxiliary_variables, tuplejoin, components))
251+
component_auxvars = fastmap(auxiliary_variables, components)
252+
auxvars = tuplejoin(component_auxvars...)
253+
return get_fields(state, auxvars)
247254
end
248255

249256
"""
@@ -252,7 +259,9 @@ end
252259
Retrieves all `Field`s from `state` corresponding to closure variables defined on the given `components`.
253260
"""
254261
@inline function closure_fields(state, components...)
255-
return get_fields(state, mapreduce(closure_variables, tuplejoin, components))
262+
component_closurevars = fastmap(closure_variables, components)
263+
closurevars = tuplejoin(component_closurevars...)
264+
return get_fields(state, closurevars)
256265
end
257266

258267
"""
@@ -261,7 +270,9 @@ end
261270
Retrieves all `Field`s from `state` corresponding to input variables defined on the given `components`.
262271
"""
263272
@inline function input_fields(state, components...)
264-
return get_fields(state, mapreduce(input_variables, tuplejoin, components))
273+
component_inputvars = fastmap(input_variables, components)
274+
inputvars = tuplejoin(component_inputvars...)
275+
return get_fields(state, inputvars)
265276
end
266277

267278
# Initialization of StateVariables from models and processes

src/utils/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ convert_dt(Δt::Period) = Second(Δt).value
2222
2323
Evaluates `x / (y + eps(NF))` if and only if `y != zero(y)`; returns `Inf` otherwise.
2424
"""
25-
safediv(x::NF, y::NF) where {NF} = ifelse(iszero(y), Inf, x / (y + eps(NF)))
25+
safediv(x::NF, y::NF) where {NF} = ifelse(iszero(y), NF(Inf), x / (y + eps(NF)))
2626

2727
# fastmap and fastiterate
2828

0 commit comments

Comments
 (0)