diff --git a/src/destructure.jl b/src/destructure.jl index 2b91983d..d8eb0378 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -130,13 +130,20 @@ function _grad!(x, dx, off, flat::AbstractVector) foreach((xᵢ, dxᵢ, oᵢ) -> _grad!(xᵢ, dxᵢ, oᵢ, flat), x′, dx′, off′) flat end -function _grad!(x, dx, off::Integer, flat::AbstractVector) +function _grad!(x::T, dx::T, off::Integer, flat::AbstractVector) where T @views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes flat end _grad!(x, dx::Zero, off, flat::AbstractVector) = dx _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = dx # ambiguity +function _grad!(x::T, dx::S, off::Integer, flat::AbstractVector) where {T, S} + flat = similar(dx, length(flat)) + @views flat[off .+ (1:length(x))] .+= vec(dx) # must visit all tied nodes + flat +end + + # These are only needed for 2nd derivatives: function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) @warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3