diff --git a/src/eval.jl b/src/eval.jl index d66ac55..68eadd1 100644 --- a/src/eval.jl +++ b/src/eval.jl @@ -48,15 +48,18 @@ import ChainRulesCore function ChainRulesCore.rrule(ev::Eval, args...) Z = ev.fwd(args...) - Z, function tullio_back(Δ) - isnothing(ev.rev) && error("no gradient definition here!") + function tullio_back(Δ) dxs = map(ev.rev(Δ, Z, args...)) do dx dx === nothing ? ChainRulesCore.ZeroTangent() : dx end - tuple(ChainRulesCore.ZeroTangent(), dxs...) + return (ChainRulesCore.NoTangent(), dxs...) end + return Z, tullio_back end +# without gradient definition we let the AD system differentiate the function +ChainRulesCore.@opt_out ChainRulesCore.rrule(ev::Eval{<:Any,Nothing}, args...) + @init @require FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" begin using .FillArrays: Fill # used by Zygote Tullio.promote_storage(::Type{T}, ::Type{F}) where {T, F<:Fill} = T