Summary of this patch (free text before the first "diff --git" line):

* jacobian.jl, pullback.jl, pushforward.jl: the example seeds handed to the
  preparation functions are built with `basis(...)` (or `oneunit(...)` when the
  value is a `Number`) instead of `zero(...)`.
* hessian.jl: `HVPGradientHessianPrep` gains a `seed_example` field, which
  replaces `batched_seeds[1]` in the calls to `prepare_hvp_nokwarg` and
  `prepare_hvp_same_point`.
* utils/basis.jl: the shared allocation / conversion steps are factored into
  `pre_basis` and `post_basis`, and a one-argument `basis(a)` method is added
  that sets no entry (the added tests expect `basis([1, 2]) == [0, 0]` and
  `basis(Int[]) == Int[]`).

diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl
index 9e55264b3..1b5ce5ccd 100644
--- a/DifferentiationInterface/src/first_order/jacobian.jl
+++ b/DifferentiationInterface/src/first_order/jacobian.jl
@@ -215,7 +215,7 @@ function _prepare_jacobian_aux(
         ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
-    seed_example = ntuple(b -> zero(x), Val(B))
+    seed_example = ntuple(b -> basis(x), Val(B))
     pushforward_prep = prepare_pushforward_nokwarg(
         strict, f_or_f!y..., backend, x, seed_example, contexts...
     )
@@ -246,7 +246,7 @@ function _prepare_jacobian_aux(
         ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
-    seed_example = ntuple(b -> zero(y), Val(B))
+    seed_example = ntuple(b -> basis(y), Val(B))
     pullback_prep = prepare_pullback_nokwarg(
         strict, f_or_f!y..., backend, x, seed_example, contexts...
     )
diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl
index 15477bf9a..7ba3a83df 100644
--- a/DifferentiationInterface/src/first_order/pullback.jl
+++ b/DifferentiationInterface/src/first_order/pullback.jl
@@ -285,7 +285,11 @@ function _prepare_pullback_aux(
     contexts::Vararg{Context,C};
 ) where {F,C}
     _sig = signature(f, backend, x, ty, contexts...; strict)
-    dx = zero(x)
+    dx = if x isa Number
+        oneunit(x)
+    else
+        basis(x)
+    end
     pushforward_prep = prepare_pushforward_nokwarg(
         strict, f, backend, x, (dx,), contexts...
     )
@@ -303,7 +307,11 @@ function _prepare_pullback_aux(
     contexts::Vararg{Context,C};
 ) where {F,C}
     _sig = signature(f!, y, backend, x, ty, contexts...; strict)
-    dx = zero(x)
+    dx = if x isa Number
+        oneunit(x)
+    else
+        basis(x)
+    end
     pushforward_prep = prepare_pushforward_nokwarg(
         strict, f!, y, backend, x, (dx,), contexts...
     )
diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl
index b029f66e8..c244e3289 100644
--- a/DifferentiationInterface/src/first_order/pushforward.jl
+++ b/DifferentiationInterface/src/first_order/pushforward.jl
@@ -290,7 +290,11 @@ function _prepare_pushforward_aux(
 ) where {F,C}
     _sig = signature(f, backend, x, tx, contexts...; strict)
     y = f(x, map(unwrap, contexts)...)
-    dy = zero(y)
+    dy = if y isa Number
+        oneunit(y)
+    else
+        basis(y)
+    end
     pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
     return PullbackPushforwardPrep(_sig, pullback_prep)
 end
@@ -306,7 +310,7 @@ function _prepare_pushforward_aux(
     contexts::Vararg{Context,C};
 ) where {F,C}
     _sig = signature(f!, y, backend, x, tx, contexts...; strict)
-    dy = zero(y)
+    dy = basis(y)
     pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
     return PullbackPushforwardPrep(_sig, pullback_prep)
 end
diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl
index ffb6ba840..a815fb08c 100644
--- a/DifferentiationInterface/src/second_order/hessian.jl
+++ b/DifferentiationInterface/src/second_order/hessian.jl
@@ -84,6 +84,7 @@ struct HVPGradientHessianPrep{
     BS<:BatchSizeSettings,
     S<:AbstractVector{<:NTuple},
     R<:AbstractVector{<:NTuple},
+    SE<:NTuple,
     E2<:HVPPrep,
     E1<:GradientPrep,
 } <: HessianPrep{SIG}
@@ -91,6 +92,7 @@ struct HVPGradientHessianPrep{
     batch_size_settings::BS
     batched_seeds::S
     batched_results::R
+    seed_example::SE
     hvp_prep::E2
     gradient_prep::E1
 end
@@ -119,10 +121,17 @@ function _prepare_hessian_aux(
         ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
     ]
     batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
-    hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, batched_seeds[1], contexts...)
+    seed_example = ntuple(b -> basis(x), Val(B))
+    hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, seed_example, contexts...)
     gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...)
     return HVPGradientHessianPrep(
-        _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
+        _sig,
+        batch_size_settings,
+        batched_seeds,
+        batched_results,
+        seed_example,
+        hvp_prep,
+        gradient_prep,
     )
 end
 
@@ -150,11 +159,11 @@ function hessian(
     contexts::Vararg{Context,C},
 ) where {F,SIG,B,aligned,C}
     check_prep(f, prep, backend, x, contexts...)
-    (; batch_size_settings, batched_seeds, hvp_prep) = prep
+    (; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep
     (; A, B_last) = batch_size_settings
 
     hvp_prep_same = prepare_hvp_same_point(
-        f, hvp_prep, backend, x, batched_seeds[1], contexts...
+        f, hvp_prep, backend, x, seed_example, contexts...
     )
 
     hess = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -178,11 +187,11 @@ function hessian!(
     contexts::Vararg{Context,C},
 ) where {F,SIG,B,C}
     check_prep(f, prep, backend, x, contexts...)
-    (; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep
+    (; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep
     (; N) = batch_size_settings
 
     hvp_prep_same = prepare_hvp_same_point(
-        f, hvp_prep, backend, x, batched_seeds[1], contexts...
+        f, hvp_prep, backend, x, seed_example, contexts...
     )
 
     for a in eachindex(batched_seeds, batched_results)
diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl
index 7fc38fcec..557708177 100644
--- a/DifferentiationInterface/src/utils/basis.jl
+++ b/DifferentiationInterface/src/utils/basis.jl
@@ -1,12 +1,6 @@
-"""
-    basis(a::AbstractArray, i)
+pre_basis(a::AbstractArray{T}) where {T} = fill!(similar(a), zero(T))
 
-Construct the `i`-th standard basis array in the vector space of `a`.
-"""
-function basis(a::AbstractArray{T}, i) where {T}
-    b = similar(a)
-    fill!(b, zero(T))
-    b[i] = oneunit(T)
+function post_basis(b::AbstractArray, a::AbstractArray)
     if ismutable_array(a)
         return b
     else
@@ -14,16 +8,32 @@ function basis(a::AbstractArray{T}, i) where {T}
     end
 end
 
+"""
+    basis(a::AbstractArray, i)
+
+Construct the `i`-th standard basis array in the vector space of `a`.
+"""
+function basis(a::AbstractArray, i)
+    b = pre_basis(a)
+    b[i] = oneunit(eltype(b))
+    return post_basis(b, a)
+end
+
+# compatible with zero-length vectors
+function basis(a::AbstractArray)
+    b = pre_basis(a)
+    return post_basis(b, a)
+end
+
 """
     multibasis(a::AbstractArray, inds)
 
 Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
 """
-function multibasis(a::AbstractArray{T}, inds) where {T}
-    b = similar(a)
-    fill!(b, zero(T))
+function multibasis(a::AbstractArray, inds)
+    b = pre_basis(a)
     for i in inds
-        b[i] = oneunit(T)
+        b[i] = oneunit(eltype(b))
     end
-    return ismutable_array(a) ? b : map(+, zero(a), b)
+    return post_basis(b, a)
 end
diff --git a/DifferentiationInterface/test/Core/Internals/basis.jl b/DifferentiationInterface/test/Core/Internals/basis.jl
index e79829990..4104bd0fb 100644
--- a/DifferentiationInterface/test/Core/Internals/basis.jl
+++ b/DifferentiationInterface/test/Core/Internals/basis.jl
@@ -26,4 +26,7 @@ using Dates
 
     t = [Time(1) - Time(0)]
     @test basis(t, 1) isa Vector{Nanosecond}
+
+    @test basis([1, 2]) == [0, 0]
+    @test basis(Int[]) == Int[]
 end