Skip to content

Commit d76db32

Browse files
penelopeysmgdalle
andauthored
fix: improve support for empty inputs (still not guaranteed) (#835)
* Avoid batch size of 0 for empty inputs * Add more support for zero batch size (incomplete Jacobian and Hessian) * test: add proper empty tests * chore: Bump DI, update changelogs * fix: fix type stability * fix: use zero as example --------- Co-authored-by: Guillaume Dalle <[email protected]>
1 parent 4a59a3a commit d76db32

File tree

17 files changed

+107
-30
lines changed

17 files changed

+107
-30
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: true # TODO: toggle
28+
fail-fast: false # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...main)
8+
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.7...main)
9+
10+
## [0.7.7](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...DifferentiationInterface-v0.7.7)
11+
12+
- Improve support for empty inputs (still not guaranteed) ([#835](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/835))
913

1014
## [0.7.6](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...DifferentiationInterface-v0.7.6)
1115

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.6"
4+
version = "0.7.7"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,14 @@ struct PushforwardJacobianPrep{
138138
BS<:BatchSizeSettings,
139139
S<:AbstractVector{<:NTuple},
140140
R<:AbstractVector{<:NTuple},
141+
SE<:NTuple,
141142
E<:PushforwardPrep,
142143
} <: StandardJacobianPrep{SIG}
143144
_sig::Val{SIG}
144145
batch_size_settings::BS
145146
batched_seeds::S
146147
batched_results::R
148+
seed_example::SE
147149
pushforward_prep::E
148150
end
149151

@@ -152,12 +154,14 @@ struct PullbackJacobianPrep{
152154
BS<:BatchSizeSettings,
153155
S<:AbstractVector{<:NTuple},
154156
R<:AbstractVector{<:NTuple},
157+
SE<:NTuple,
155158
E<:PullbackPrep,
156159
} <: StandardJacobianPrep{SIG}
157160
_sig::Val{SIG}
158161
batch_size_settings::BS
159162
batched_seeds::S
160163
batched_results::R
164+
seed_example::SE
161165
pullback_prep::E
162166
end
163167

@@ -211,11 +215,17 @@ function _prepare_jacobian_aux(
211215
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
212216
]
213217
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
218+
seed_example = ntuple(b -> zero(x), Val(B))
214219
pushforward_prep = prepare_pushforward_nokwarg(
215-
strict, f_or_f!y..., backend, x, batched_seeds[1], contexts...
220+
strict, f_or_f!y..., backend, x, seed_example, contexts...
216221
)
217222
return PushforwardJacobianPrep(
218-
_sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
223+
_sig,
224+
batch_size_settings,
225+
batched_seeds,
226+
batched_results,
227+
seed_example,
228+
pushforward_prep,
219229
)
220230
end
221231

@@ -236,11 +246,17 @@ function _prepare_jacobian_aux(
236246
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
237247
]
238248
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
249+
seed_example = ntuple(b -> zero(y), Val(B))
239250
pullback_prep = prepare_pullback_nokwarg(
240-
strict, f_or_f!y..., backend, x, batched_seeds[1], contexts...
251+
strict, f_or_f!y..., backend, x, seed_example, contexts...
241252
)
242253
return PullbackJacobianPrep(
243-
_sig, batch_size_settings, batched_seeds, batched_results, pullback_prep
254+
_sig,
255+
batch_size_settings,
256+
batched_seeds,
257+
batched_results,
258+
seed_example,
259+
pullback_prep,
244260
)
245261
end
246262

@@ -363,11 +379,11 @@ function _jacobian_aux(
363379
x,
364380
contexts::Vararg{Context,C},
365381
) where {FY,SIG,B,aligned,C}
366-
(; batch_size_settings, batched_seeds, pushforward_prep) = prep
382+
(; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep
367383
(; A, B_last) = batch_size_settings
368384

369385
pushforward_prep_same = prepare_pushforward_same_point(
370-
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
386+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
371387
)
372388

373389
jac = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -419,11 +435,11 @@ function _jacobian_aux(
419435
x,
420436
contexts::Vararg{Context,C},
421437
) where {FY,SIG,B,aligned,C}
422-
(; batch_size_settings, batched_seeds, pullback_prep) = prep
438+
(; batch_size_settings, batched_seeds, seed_example, pullback_prep) = prep
423439
(; A, B_last) = batch_size_settings
424440

425441
pullback_prep_same = prepare_pullback_same_point(
426-
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
442+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
427443
)
428444

429445
jac = mapreduce(vcat, eachindex(batched_seeds)) do a
@@ -451,11 +467,13 @@ function _jacobian_aux!(
451467
x,
452468
contexts::Vararg{Context,C},
453469
) where {FY,SIG,B,C}
454-
(; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep
470+
(;
471+
batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep
472+
) = prep
455473
(; N) = batch_size_settings
456474

457475
pushforward_prep_same = prepare_pushforward_same_point(
458-
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
476+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
459477
)
460478

461479
for a in eachindex(batched_seeds, batched_results)
@@ -487,11 +505,12 @@ function _jacobian_aux!(
487505
x,
488506
contexts::Vararg{Context,C},
489507
) where {FY,SIG,B,C}
490-
(; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep
508+
(; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) =
509+
prep
491510
(; N) = batch_size_settings
492511

493512
pullback_prep_same = prepare_pullback_same_point(
494-
f_or_f!y..., pullback_prep, backend, x, batched_seeds[1], contexts...
513+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
495514
)
496515

497516
for a in eachindex(batched_seeds, batched_results)

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ function _prepare_pullback_aux(
285285
contexts::Vararg{Context,C};
286286
) where {F,C}
287287
_sig = signature(f, backend, x, ty, contexts...; strict)
288-
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
288+
dx = zero(x)
289289
pushforward_prep = prepare_pushforward_nokwarg(
290290
strict, f, backend, x, (dx,), contexts...
291291
)
@@ -303,7 +303,7 @@ function _prepare_pullback_aux(
303303
contexts::Vararg{Context,C};
304304
) where {F,C}
305305
_sig = signature(f!, y, backend, x, ty, contexts...; strict)
306-
dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x)))
306+
dx = zero(x)
307307
pushforward_prep = prepare_pushforward_nokwarg(
308308
strict, f!, y, backend, x, (dx,), contexts...
309309
)

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function _prepare_pushforward_aux(
290290
) where {F,C}
291291
_sig = signature(f, backend, x, tx, contexts...; strict)
292292
y = f(x, map(unwrap, contexts)...)
293-
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
293+
dy = zero(y)
294294
pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...)
295295
return PullbackPushforwardPrep(_sig, pullback_prep)
296296
end
@@ -306,7 +306,7 @@ function _prepare_pushforward_aux(
306306
contexts::Vararg{Context,C};
307307
) where {F,C}
308308
_sig = signature(f!, y, backend, x, tx, contexts...; strict)
309-
dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y)))
309+
dy = zero(y)
310310
pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...)
311311
return PullbackPushforwardPrep(_sig, pullback_prep)
312312
end

DifferentiationInterface/src/utils/batchsize.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Configuration for the batch size deduced from a backend and a sample array of le
66
# Type parameters
77
88
- `B::Int`: batch size
9-
- `singlebatch::Bool`: whether `B == N` (`B > N` is not allowed)
9+
- `singlebatch::Bool`: whether `B == N` (`B > N` is only allowed when `N == 0`)
1010
- `aligned::Bool`: whether `N % B == 0`
1111
1212
# Fields
@@ -22,22 +22,26 @@ struct BatchSizeSettings{B,singlebatch,aligned}
2222
end
2323

2424
function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebatch,aligned}
25-
B > N && throw(ArgumentError("Batch size $B larger than input size $N"))
26-
A = div(N, B, RoundUp)
27-
B_last = N % B
25+
B > N > 0 && throw(ArgumentError("Batch size $B larger than input size $N"))
26+
if B == N == 0
27+
A = B_last = 0
28+
else
29+
A = div(N, B, RoundUp)
30+
B_last = N % B
31+
end
2832
return BatchSizeSettings{B,singlebatch,aligned}(N, A, B_last)
2933
end
3034

3135
function BatchSizeSettings{B}(::Val{N}) where {B,N}
3236
singlebatch = B == N
33-
aligned = N % B == 0
37+
aligned = (B == N == 0) || (N % B == 0)
3438
return BatchSizeSettings{B,singlebatch,aligned}(N)
3539
end
3640

3741
function BatchSizeSettings{B}(N::Integer) where {B}
3842
# type-unstable
3943
singlebatch = B == N
40-
aligned = N % B == 0
44+
aligned = (B == N == 0) || (N % B == 0)
4145
return BatchSizeSettings{B,singlebatch,aligned}(N)
4246
end
4347

@@ -123,7 +127,9 @@ Reproduces the heuristic from ForwardDiff to minimize
123127
Source: https://github.com/JuliaDiff/ForwardDiff.jl/blob/ec74fbc32b10bbf60b3c527d8961666310733728/src/prelude.jl#L19-L29
124128
"""
125129
function reasonable_batchsize(N::Integer, Bmax::Integer)
126-
if N <= Bmax
130+
if N == 0
131+
return 1
132+
elseif N <= Bmax
127133
return N
128134
else
129135
A = div(N, Bmax, RoundUp)

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,11 @@ end
201201
@test occursin("DifferentiationInterface", msg)
202202
end
203203
end
204+
205+
@testset "Empty arrays" begin
206+
test_differentiation(
207+
[AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)],
208+
empty_scenarios();
209+
excluded=[:jacobian],
210+
)
211+
end;

DifferentiationInterface/test/Core/Internals/batchsize.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ using Test
1313
BSS = BatchSizeSettings
1414

1515
@testset "Default" begin
16+
@test (@inferred pick_batchsize(AutoZygote(), zeros(0))) isa BSS{1,false,true}
1617
@test (@inferred pick_batchsize(AutoZygote(), zeros(2))) isa BSS{1,false,true}
1718
@test (@inferred pick_batchsize(AutoZygote(), zeros(100))) isa BSS{1,false,true}
1819
@test_throws ArgumentError pick_batchsize(AutoSparse(AutoZygote()), zeros(2))
@@ -25,11 +26,14 @@ BSS = BatchSizeSettings
2526
end
2627

2728
@testset "SimpleFiniteDiff (adaptive)" begin
29+
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(0))) isa BSS{1,false,true}
2830
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(2))) isa BSS{2,true,true}
2931
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(6))) isa BSS{6,true,true}
3032
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(12))) isa BSS{12,true,true}
3133
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(24))) isa BSS{12,false,true}
3234
@test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(100))) isa BSS{12,false,false}
35+
@test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(0)))) isa
36+
BSS{0,true,true}
3337
@test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(2)))) isa
3438
BSS{2,true,true}
3539
@test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(6)))) isa

DifferentiationInterface/test/Core/ZeroBackends/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using DifferentiationInterface
22
using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
33
using DifferentiationInterfaceTest
4+
using LinearAlgebra
45
using ComponentArrays: ComponentArrays
56
using JLArrays: JLArrays
67
using SparseMatrixColorings
@@ -50,3 +51,9 @@ end
5051
logging=LOGGING,
5152
)
5253
end
54+
55+
@testset "Empty arrays" begin
56+
test_differentiation(
57+
[AutoZeroForward(), AutoZeroReverse()], empty_scenarios(); excluded=[:jacobian]
58+
)
59+
end;

0 commit comments

Comments
 (0)