Skip to content

Commit 02e2e56

Browse files
authored
fix: add basis handling for empty arrays (#843)
* fix: add basis handling for empty arrays
* fix: forgotten field
* fix: remove dead code
1 parent d76db32 commit 02e2e56

File tree

6 files changed

+59
-25
lines changed

6 files changed

+59
-25
lines changed

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ function _prepare_jacobian_aux(
215215
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
216216
]
217217
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
218-
seed_example = ntuple(b -> zero(x), Val(B))
218+
seed_example = ntuple(b -> basis(x), Val(B))
219219
pushforward_prep = prepare_pushforward_nokwarg(
220220
strict, f_or_f!y..., backend, x, seed_example, contexts...
221221
)
@@ -246,7 +246,7 @@ function _prepare_jacobian_aux(
246246
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
247247
]
248248
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
249-
seed_example = ntuple(b -> zero(y), Val(B))
249+
seed_example = ntuple(b -> basis(y), Val(B))
250250
pullback_prep = prepare_pullback_nokwarg(
251251
strict, f_or_f!y..., backend, x, seed_example, contexts...
252252
)

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,11 @@ function _prepare_pullback_aux(
285285
contexts::Vararg{Context,C};
286286
) where {F,C}
287287
_sig = signature(f, backend, x, ty, contexts...; strict)
288-
dx = zero(x)
288+
dx = if x isa Number
289+
oneunit(x)
290+
else
291+
basis(x)
292+
end
289293
pushforward_prep = prepare_pushforward_nokwarg(
290294
strict, f, backend, x, (dx,), contexts...
291295
)
@@ -303,7 +307,11 @@ function _prepare_pullback_aux(
303307
contexts::Vararg{Context,C};
304308
) where {F,C}
305309
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
306-
dx = zero(x)
310+
dx = if x isa Number
311+
oneunit(x)
312+
else
313+
basis(x)
314+
end
307315
pushforward_prep = prepare_pushforward_nokwarg(
308316
strict, f!, y, backend, x, (dx,), contexts...
309317
)

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,11 @@ function _prepare_pushforward_aux(
290290
) where {F,C}
291291
_sig = signature(f, backend, x, tx, contexts...; strict)
292292
y = f(x, map(unwrap, contexts)...)
293-
dy = zero(y)
293+
dy = if y isa Number
294+
oneunit(y)
295+
else
296+
basis(y)
297+
end
294298
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
295299
return PullbackPushforwardPrep(_sig, pullback_prep)
296300
end
@@ -306,7 +310,7 @@ function _prepare_pushforward_aux(
306310
contexts::Vararg{Context,C};
307311
) where {F,C}
308312
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
309-
dy = zero(y)
313+
dy = basis(y)
310314
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
311315
return PullbackPushforwardPrep(_sig, pullback_prep)
312316
end

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,15 @@ struct HVPGradientHessianPrep{
8484
BS<:BatchSizeSettings,
8585
S<:AbstractVector{<:NTuple},
8686
R<:AbstractVector{<:NTuple},
87+
SE<:NTuple,
8788
E2<:HVPPrep,
8889
E1<:GradientPrep,
8990
} <: HessianPrep{SIG}
9091
_sig::Val{SIG}
9192
batch_size_settings::BS
9293
batched_seeds::S
9394
batched_results::R
95+
seed_example::SE
9496
hvp_prep::E2
9597
gradient_prep::E1
9698
end
@@ -119,10 +121,17 @@ function _prepare_hessian_aux(
119121
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
120122
]
121123
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
122-
hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, batched_seeds[1], contexts...)
124+
seed_example = ntuple(b -> basis(x), Val(B))
125+
hvp_prep = prepare_hvp_nokwarg(strict, f, backend, x, seed_example, contexts...)
123126
gradient_prep = prepare_gradient_nokwarg(strict, f, inner(backend), x, contexts...)
124127
return HVPGradientHessianPrep(
125-
_sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
128+
_sig,
129+
batch_size_settings,
130+
batched_seeds,
131+
batched_results,
132+
seed_example,
133+
hvp_prep,
134+
gradient_prep,
126135
)
127136
end
128137

@@ -150,11 +159,11 @@ function hessian(
150159
contexts::Vararg{Context,C},
151160
) where {F,SIG,B,aligned,C}
152161
check_prep(f, prep, backend, x, contexts...)
153-
(; batch_size_settings, batched_seeds, hvp_prep) = prep
162+
(; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep
154163
(; A, B_last) = batch_size_settings
155164

156165
hvp_prep_same = prepare_hvp_same_point(
157-
f, hvp_prep, backend, x, batched_seeds[1], contexts...
166+
f, hvp_prep, backend, x, seed_example, contexts...
158167
)
159168

160169
hess = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -178,11 +187,11 @@ function hessian!(
178187
contexts::Vararg{Context,C},
179188
) where {F,SIG,B,C}
180189
check_prep(f, prep, backend, x, contexts...)
181-
(; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep
190+
(; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep
182191
(; N) = batch_size_settings
183192

184193
hvp_prep_same = prepare_hvp_same_point(
185-
f, hvp_prep, backend, x, batched_seeds[1], contexts...
194+
f, hvp_prep, backend, x, seed_example, contexts...
186195
)
187196

188197
for a in eachindex(batched_seeds, batched_results)
Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,39 @@
1-
"""
2-
basis(a::AbstractArray, i)
1+
pre_basis(a::AbstractArray{T}) where {T} = fill!(similar(a), zero(T))
32

4-
Construct the `i`-th standard basis array in the vector space of `a`.
5-
"""
6-
function basis(a::AbstractArray{T}, i) where {T}
7-
b = similar(a)
8-
fill!(b, zero(T))
9-
b[i] = oneunit(T)
3+
function post_basis(b::AbstractArray, a::AbstractArray)
104
if ismutable_array(a)
115
return b
126
else
137
return map(+, zero(a), b)
148
end
159
end
1610

11+
"""
12+
basis(a::AbstractArray, i)
13+
14+
Construct the `i`-th standard basis array in the vector space of `a`.
15+
"""
16+
function basis(a::AbstractArray, i)
17+
b = pre_basis(a)
18+
b[i] = oneunit(eltype(b))
19+
return post_basis(b, a)
20+
end
21+
22+
# compatible with zero-length vectors
23+
function basis(a::AbstractArray)
24+
b = pre_basis(a)
25+
return post_basis(b, a)
26+
end
27+
1728
"""
1829
multibasis(a::AbstractArray, inds)
1930
2031
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
2132
"""
22-
function multibasis(a::AbstractArray{T}, inds) where {T}
23-
b = similar(a)
24-
fill!(b, zero(T))
33+
function multibasis(a::AbstractArray, inds)
34+
b = pre_basis(a)
2535
for i in inds
26-
b[i] = oneunit(T)
36+
b[i] = oneunit(eltype(b))
2737
end
28-
return ismutable_array(a) ? b : map(+, zero(a), b)
38+
return post_basis(b, a)
2939
end

DifferentiationInterface/test/Core/Internals/basis.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,7 @@ using Dates
2626

2727
t = [Time(1) - Time(0)]
2828
@test basis(t, 1) isa Vector{Nanosecond}
29+
30+
@test basis([1, 2]) == [0, 0]
31+
@test basis(Int[]) == Int[]
2932
end

0 commit comments

Comments
 (0)