Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ function _prepare_jacobian_aux(
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
]
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
seed_example = ntuple(b -> zero(x), Val(B))
seed_example = ntuple(b -> basis(x), Val(B))
pushforward_prep = prepare_pushforward_nokwarg(
strict, f_or_f!y..., backend, x, seed_example, contexts...
)
Expand Down Expand Up @@ -246,7 +246,7 @@ function _prepare_jacobian_aux(
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
]
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
seed_example = ntuple(b -> zero(y), Val(B))
seed_example = ntuple(b -> basis(y), Val(B))
pullback_prep = prepare_pullback_nokwarg(
strict, f_or_f!y..., backend, x, seed_example, contexts...
)
Expand Down
12 changes: 10 additions & 2 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ function _prepare_pullback_aux(
contexts::Vararg{Context,C};
) where {F,C}
_sig = signature(f, backend, x, ty, contexts...; strict)
dx = zero(x)
dx = if x isa Number
oneunit(x)
else
basis(x)
end
pushforward_prep = prepare_pushforward_nokwarg(
strict, f, backend, x, (dx,), contexts...
)
Expand All @@ -303,7 +307,11 @@ function _prepare_pullback_aux(
contexts::Vararg{Context,C};
) where {F,C}
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
dx = zero(x)
dx = if x isa Number
oneunit(x)
else
basis(x)
end
pushforward_prep = prepare_pushforward_nokwarg(
strict, f!, y, backend, x, (dx,), contexts...
)
Expand Down
8 changes: 6 additions & 2 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,11 @@ function _prepare_pushforward_aux(
) where {F,C}
_sig = signature(f, backend, x, tx, contexts...; strict)
y = f(x, map(unwrap, contexts)...)
dy = zero(y)
dy = if y isa Number
oneunit(y)
else
basis(y)
end
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
return PullbackPushforwardPrep(_sig, pullback_prep)
end
Expand All @@ -306,7 +310,7 @@ function _prepare_pushforward_aux(
contexts::Vararg{Context,C};
) where {F,C}
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
dy = zero(y)
dy = basis(y)
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
return PullbackPushforwardPrep(_sig, pullback_prep)
end
Expand Down
21 changes: 15 additions & 6 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,15 @@ struct HVPGradientHessianPrep{
BS<:BatchSizeSettings,
S<:AbstractVector{<:NTuple},
R<:AbstractVector{<:NTuple},
SE<:NTuple,
E2<:HVPPrep,
E1<:GradientPrep,
} <: HessianPrep{SIG}
_sig::Val{SIG}
batch_size_settings::BS
batched_seeds::S
batched_results::R
seed_example::SE
hvp_prep::E2
gradient_prep::E1
end
Expand Down Expand Up @@ -119,10 +121,17 @@ function _prepare_hessian_aux(
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
]
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, batched_seeds[1], contexts...)
seed_example = ntuple(b -> basis(x), Val(B))
hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, seed_example, contexts...)
gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...)
return HVPGradientHessianPrep(
_sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
_sig,
batch_size_settings,
batched_seeds,
batched_results,
seed_example,
hvp_prep,
gradient_prep,
)
end

Expand Down Expand Up @@ -150,11 +159,11 @@ function hessian(
contexts::Vararg{Context,C},
) where {F,SIG,B,aligned,C}
check_prep(f, prep, backend, x, contexts...)
(; batch_size_settings, batched_seeds, hvp_prep) = prep
(; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep
(; A, B_last) = batch_size_settings

hvp_prep_same = prepare_hvp_same_point(
f, hvp_prep, backend, x, batched_seeds[1], contexts...
f, hvp_prep, backend, x, seed_example, contexts...
)

hess = mapreduce(hcat, eachindex(batched_seeds)) do a
Expand All @@ -178,11 +187,11 @@ function hessian!(
contexts::Vararg{Context,C},
) where {F,SIG,B,C}
check_prep(f, prep, backend, x, contexts...)
(; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep
(; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep
(; N) = batch_size_settings

hvp_prep_same = prepare_hvp_same_point(
f, hvp_prep, backend, x, batched_seeds[1], contexts...
f, hvp_prep, backend, x, seed_example, contexts...
)

for a in eachindex(batched_seeds, batched_results)
Expand Down
36 changes: 23 additions & 13 deletions DifferentiationInterface/src/utils/basis.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
"""
basis(a::AbstractArray, i)
pre_basis(a::AbstractArray{T}) where {T} = fill!(similar(a), zero(T))

Construct the `i`-th standard basis array in the vector space of `a`.
"""
function basis(a::AbstractArray{T}, i) where {T}
b = similar(a)
fill!(b, zero(T))
b[i] = oneunit(T)
function post_basis(b::AbstractArray, a::AbstractArray)
if ismutable_array(a)
return b
else
return map(+, zero(a), b)
end
end

"""
basis(a::AbstractArray, i)

Construct the `i`-th standard basis array in the vector space of `a`.
"""
function basis(a::AbstractArray, i)
b = pre_basis(a)
b[i] = oneunit(eltype(b))
return post_basis(b, a)
end

# compatible with zero-length vectors
function basis(a::AbstractArray)
b = pre_basis(a)
return post_basis(b, a)
end

"""
multibasis(a::AbstractArray, inds)

Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
"""
function multibasis(a::AbstractArray{T}, inds) where {T}
b = similar(a)
fill!(b, zero(T))
function multibasis(a::AbstractArray, inds)
b = pre_basis(a)
for i in inds
b[i] = oneunit(T)
b[i] = oneunit(eltype(b))
end
return ismutable_array(a) ? b : map(+, zero(a), b)
return post_basis(b, a)
end
3 changes: 3 additions & 0 deletions DifferentiationInterface/test/Core/Internals/basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ using Dates

t = [Time(1) - Time(0)]
@test basis(t, 1) isa Vector{Nanosecond}

@test basis([1, 2]) == [0, 0]
@test basis(Int[]) == Int[]
end
Loading