Description
To demonstrate basic differentiability, we aim to do a sensitivity analysis, e.g. by computing the derivative of the vorticity at a single point after N days with respect to the initial state.
Currently, a simplified example of this is failing. I tried several slightly different versions of writing vorticity_at_gridpoint! and all of them failed (some with different errors). As @bgroenks96 pointed out, we could already do exactly this kind of sensitivity analysis by using a one-hot array as the seed when directly differentiating timestep! (which works). However, if an error already occurs in this simple case, it will likely also cause problems in other things we want to look at.
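A minimal sketch of the one-hot-seed approach, assuming a hypothetical toy in-place step `toy_timestep!` (an explicit Euler step of du/dt = -u^2) as a stand-in for `SpeedyWeather.timestep!`. The shadow of the state is seeded with the one-hot vector e_i; after the reverse pass it holds the vector-Jacobian product e_iᵀJ, i.e. the sensitivity of component i after the step with respect to every component before it.

```julia
using Enzyme

# hypothetical stand-in for an in-place timestep!: explicit Euler step of du/dt = -u^2
function toy_timestep!(u, dt)
    @inbounds for k in eachindex(u)
        u[k] = u[k] - dt * u[k]^2
    end
    return nothing
end

u = [1.0, 2.0, 3.0]
du = zeros(3)
du[2] = 1.0    # one-hot seed: pick component 2 of the state *after* the step

# reverse mode: du goes in as the adjoint of the new state
# and comes out as the adjoint of the old state
autodiff(Reverse, toy_timestep!, Const, Duplicated(u, du), Const(0.1))

# du ≈ ∂u_new[2]/∂u_old = [0, 1 - 2*0.1*2.0, 0] = [0, 0.6, 0]
println(du)
```

Repeating this with a different one-hot seed per output component gives one row of the Jacobian per reverse pass, which is why a scalar-output function such as vorticity_at_gridpoint! is the more natural formulation once it works.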
using SpeedyWeather, Enzyme
# first we initialize a model
# for now, we pick our least complex model, to make everything run faster
arch = SpeedyWeather.CPU()
spectral_grid = SpectralGrid(trunc=8, nlayers=1, architecture=arch)
model = BarotropicModel(spectral_grid)
simulation = initialize!(model)
initialize!(simulation)
run!(simulation, period=Day(3)) # a bit of spin-up
const i_point = 12 # pick some arbitrary index
const dt = model.time_stepping.Δt
# in this example, to demonstrate the error, we just take some point of the vorticity field;
# that's actually a spectral field, which doesn't make much sense for a sensitivity analysis,
# but this is just to demonstrate the error with Enzyme
function vorticity_at_gridpoint!(progn, diagn, model)
    # just do a single timestep here for this example, later this will be replaced with run!(..., period=Day(N_days))
    SpeedyWeather.timestep!(progn, diagn, dt, model)
    return abs(progn.vor[i_point, 1, end])
end
progn = simulation.prognostic_variables
dprogn = zero(progn)
diagn = simulation.diagnostic_variables
ddiag = make_zero(diagn)
# run the primal function once to trigger compilation
vorticity_at_gridpoint!(progn, diagn, model)
# differentiate it; catch the error so that it is displayed and the rest of the script still runs
try
    autodiff(Reverse, vorticity_at_gridpoint!, Active, Duplicated(progn, dprogn), Duplicated(diagn, ddiag), Const(model))
catch err
    println("Enzyme error caught:")
    showerror(stdout, err)
end
# in contrast, differentiating timestep! directly works fine:
autodiff(Reverse, SpeedyWeather.timestep!, Const, Duplicated(progn, dprogn), Duplicated(diagn, ddiag), Const(dt), Const(model))
This currently yields a MethodError. When I last tried this two months ago, the error was different. Now it looks like a rule is missing: Enzyme's activity analysis determines Const activity for the output of our Fourier transform, but our custom rule is only defined for Duplicated.
ERROR: MethodError: no method matching augmented_primal(::EnzymeCore.EnzymeRules.RevConfigWidth{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Const{…}, ::Duplicated{…}, ::Const{…})
Closest candidates are:
augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::Const{typeof(Enzyme.pmap)}, ::Type{Const{Nothing}}, ::BodyTy, ::Any, ::Annotation...) where {BodyTy, N}
@ Enzyme ~/.julia/packages/Enzyme/12QGc/src/internal_rules.jl:337
augmented_primal(::EnzymeCore.EnzymeRules.RevConfig, ::Const{<:KernelAbstractions.Kernel}, ::Type{Const{Nothing}}, ::Any...; ndrange, workgroupsize) where N
@ EnzymeExt ~/.julia/packages/KernelAbstractions/lGrz7/ext/EnzymeCore08Ext.jl:214
augmented_primal(::Any, ::Const{typeof(QuadGK.quadgk)}, ::Type{RT}, ::Any, ::Annotation{T}...; kws...) where {RT, T}
@ QuadGKEnzymeExt ~/.julia/packages/QuadGK/7rND3/ext/QuadGKEnzymeExt.jl:6
...
Some type information was truncated. Use `show(err)` to see complete types.
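A minimal sketch of one way to make a custom reverse rule robust to this, assuming a hypothetical function `scale!(out, x, a)` (out .= a .* x) in place of the actual Fourier transform and its rule. Typing the array arguments as `Annotation` instead of `Duplicated` lets the same `augmented_primal`/`reverse` pair match both `Const` and `Duplicated`, and the reverse pass only touches shadows that exist. This is an illustration of the dispatch pattern, not SpeedyWeather's actual rule.

```julia
using Enzyme
using Enzyme: EnzymeRules

# hypothetical stand-in for a transform with a hand-written rule
function scale!(out, x, a)
    out .= a .* x
    return nothing
end

# forward pass: `out` and `x` accept any activity (Const or Duplicated)
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
        func::Const{typeof(scale!)}, ::Type{<:Const},
        out::EnzymeRules.Annotation, x::EnzymeRules.Annotation, a::Const)
    func.val(out.val, x.val, a.val)    # run the primal computation
    return EnzymeRules.AugmentedReturn(nothing, nothing, nothing)    # no primal, shadow or tape needed
end

# reverse pass: only propagate through shadows that actually exist
function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
        func::Const{typeof(scale!)}, ::Type{<:Const}, tape,
        out::EnzymeRules.Annotation, x::EnzymeRules.Annotation, a::Const)
    if !(out isa Const)
        if !(x isa Const)
            x.dval .+= a.val .* out.dval    # accumulate the adjoint of x
        end
        fill!(out.dval, 0)    # out was overwritten, so its adjoint is consumed
    end
    return (nothing, nothing, nothing)    # no Active arguments
end

# output buffer is differentiated: the rule receives Duplicated(out)
loss(x, a) = (out = zero(x); scale!(out, x, a); sum(abs2, out))

# output buffer is passed in as Const: the rule receives Const(out)
loss_const_out(out, x, a) = (scale!(out, x, a); sum(x))

x = [1.0, 2.0, 3.0]
dx = zeros(3)
autodiff(Reverse, loss, Active, Duplicated(x, dx), Const(2.0))    # ∂/∂x sum((2x)^2) = 8x

dx2 = zeros(3)
autodiff(Reverse, loss_const_out, Active, Const(zeros(3)), Duplicated(x, dx2), Const(2.0))    # ∂/∂x sum(x) = 1

println(dx, " ", dx2)
```

With the arguments typed as Duplicated only, the second call would hit a MethodError of the same shape as above, since no method accepts a Const output buffer.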
The code was run on the current main branch with Julia v1.10 (there are currently still problems with Julia v1.11).