From 7498ffed4142062cf370954e9ff0b3e330987c30 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 15 Aug 2025 00:33:44 +0100 Subject: [PATCH 1/2] Avoid batch size of 0 for empty inputs --- DifferentiationInterface/src/utils/batchsize.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index 054d5c9b9..754d464ce 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -22,7 +22,7 @@ struct BatchSizeSettings{B,singlebatch,aligned} end function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebatch,aligned} - B > N && throw(ArgumentError("Batch size $B larger than input size $N")) + B > N > 0 && throw(ArgumentError("Batch size $B larger than input size $N")) A = div(N, B, RoundUp) B_last = N % B return BatchSizeSettings{B,singlebatch,aligned}(N, A, B_last) @@ -123,7 +123,9 @@ Reproduces the heuristic from ForwardDiff to minimize Source: https://github.com/JuliaDiff/ForwardDiff.jl/blob/ec74fbc32b10bbf60b3c527d8961666310733728/src/prelude.jl#L19-L29 """ function reasonable_batchsize(N::Integer, Bmax::Integer) - if N <= Bmax + if N == 0 + return 1 + elseif N <= Bmax return N else A = div(N, Bmax, RoundUp) From 8f93335fa941dda2ad4bd68aab1ecb50edfcf9e6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 16 Aug 2025 11:26:21 +0200 Subject: [PATCH 2/2] Add more support for zero batch size (incomplete Jacobian and Hessian) --- .../src/first_order/jacobian.jl | 28 +++++++++++++------ .../src/first_order/pullback.jl | 16 +++++++++-- .../src/first_order/pushforward.jl | 16 +++++++++-- .../src/utils/batchsize.jl | 16 ++++++----- .../test/Core/Internals/batchsize.jl | 3 ++ .../test/Core/ZeroBackends/test.jl | 13 +++++++++ 6 files changed, 73 insertions(+), 19 deletions(-) diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 5e6a07280..fa57dd2e4 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -153,12 +153,14 @@ struct PullbackJacobianPrep{ S<:AbstractVector{<:NTuple}, R<:AbstractVector{<:NTuple}, E<:PullbackPrep, + Y, } <: StandardJacobianPrep{SIG} _sig::Val{SIG} batch_size_settings::BS batched_seeds::S batched_results::R pullback_prep::E + y_example::Y end function prepare_jacobian_nokwarg( @@ -212,7 +214,7 @@ function _prepare_jacobian_aux( ] batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds] pushforward_prep = prepare_pushforward_nokwarg( - strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... + strict, f_or_f!y..., backend, x, ntuple(b -> zero(x), Val(B)), contexts... ) return PushforwardJacobianPrep( _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep @@ -237,10 +239,10 @@ function _prepare_jacobian_aux( ] batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds] pullback_prep = prepare_pullback_nokwarg( - strict, f_or_f!y..., backend, x, batched_seeds[1], contexts... + strict, f_or_f!y..., backend, x, ntuple(b -> zero(y), Val(B)), contexts... ) return PullbackJacobianPrep( - _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep + _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep, y ) end @@ -367,7 +369,7 @@ function _jacobian_aux( (; A, B_last) = batch_size_settings pushforward_prep_same = prepare_pushforward_same_point( - f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts... + f_or_f!y..., pushforward_prep, backend, x, ntuple(b -> zero(x), Val(B)), contexts... ) jac = mapreduce(hcat, eachindex(batched_seeds)) do a @@ -419,11 +421,16 @@ function _jacobian_aux( x, contexts::Vararg{Context,C}, ) where {FY,SIG,B,aligned,C} - (; batch_size_settings, batched_seeds, pullback_prep) = prep + (; batch_size_settings, batched_seeds, pullback_prep, y_example) = prep (; A, B_last) = batch_size_settings pullback_prep_same = prepare_pullback_same_point( - f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts... + f_or_f!y..., + pullback_prep, + backend, + x, + ntuple(b -> zero(y_example), Val(B)), + contexts..., ) jac = mapreduce(vcat, eachindex(batched_seeds)) do a @@ -487,11 +494,16 @@ function _jacobian_aux!( x, contexts::Vararg{Context,C}, ) where {FY,SIG,B,C} - (; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep + (; batch_size_settings, batched_seeds, batched_results, pullback_prep, y_example) = prep (; N) = batch_size_settings pullback_prep_same = prepare_pullback_same_point( - f_or_f!y..., pullback_prep, backend, x, batched_seeds[1], contexts... + f_or_f!y..., + pullback_prep, + backend, + x, + ntuple(b -> zero(y_example), Val(B)), + contexts..., ) for a in eachindex(batched_seeds, batched_results) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 6c6e6cc10..5512ebdc4 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -285,7 +285,13 @@ function _prepare_pullback_aux( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f, backend, x, ty, contexts...; strict) - dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x))) + dx = if x isa Number + oneunit(x) + elseif isempty(x) + zero(x) + else + basis(x, first(CartesianIndices(x))) + end pushforward_prep = prepare_pushforward_nokwarg( strict, f, backend, x, (dx,), contexts... ) @@ -303,7 +309,13 @@ function _prepare_pullback_aux( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, ty, contexts...; strict) - dx = x isa Number ? oneunit(x) : basis(x, first(CartesianIndices(x))) + dx = if x isa Number + oneunit(x) + elseif isempty(x) + zero(x) + else + basis(x, first(CartesianIndices(x))) + end pushforward_prep = prepare_pushforward_nokwarg( strict, f!, y, backend, x, (dx,), contexts... ) diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 46d249d67..6349c1f68 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -290,7 +290,13 @@ function _prepare_pushforward_aux( ) where {F,C} _sig = signature(f, backend, x, tx, contexts...; strict) y = f(x, map(unwrap, contexts)...) - dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y))) + dy = if y isa Number + oneunit(y) + elseif isempty(y) + zero(y) + else + basis(y, first(CartesianIndices(y))) + end pullback_prep = prepare_pullback_nokwarg(strict, f, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end @@ -306,7 +312,13 @@ function _prepare_pushforward_aux( contexts::Vararg{Context,C}; ) where {F,C} _sig = signature(f!, y, backend, x, tx, contexts...; strict) - dy = y isa Number ? oneunit(y) : basis(y, first(CartesianIndices(y))) + dy = if y isa Number + oneunit(y) + elseif isempty(y) + zero(y) + else + basis(y, first(CartesianIndices(y))) + end pullback_prep = prepare_pullback_nokwarg(strict, f!, y, backend, x, (dy,), contexts...) return PullbackPushforwardPrep(_sig, pullback_prep) end diff --git a/DifferentiationInterface/src/utils/batchsize.jl b/DifferentiationInterface/src/utils/batchsize.jl index 620c6ac71..510dd502b 100644 --- a/DifferentiationInterface/src/utils/batchsize.jl +++ b/DifferentiationInterface/src/utils/batchsize.jl @@ -23,21 +23,25 @@ end function BatchSizeSettings{B,singlebatch,aligned}(N::Integer) where {B,singlebatch,aligned} B > N > 0 && throw(ArgumentError("Batch size $B larger than input size $N")) - A = div(N, B, RoundUp) - B_last = N % B + if B == N == 0 + A = B_last = 0 + else + A = div(N, B, RoundUp) + B_last = N % B + end return BatchSizeSettings{B,singlebatch,aligned}(N, A, B_last) end function BatchSizeSettings{B}(::Val{N}) where {B,N} singlebatch = B == N - aligned = N % B == 0 + aligned = (B == N == 0) || (N % B == 0) return BatchSizeSettings{B,singlebatch,aligned}(N) end function BatchSizeSettings{B}(N::Integer) where {B} # type-unstable singlebatch = B == N - aligned = N % B == 0 + aligned = (B == N == 0) || (N % B == 0) return BatchSizeSettings{B,singlebatch,aligned}(N) end @@ -123,9 +127,7 @@ Reproduces the heuristic from ForwardDiff to minimize Source: https://github.com/JuliaDiff/ForwardDiff.jl/blob/ec74fbc32b10bbf60b3c527d8961666310733728/src/prelude.jl#L19-L29 """ function reasonable_batchsize(N::Integer, Bmax::Integer) - if N == 0 - return 1 - elseif N <= Bmax + if N <= Bmax return N else A = div(N, Bmax, RoundUp) diff --git a/DifferentiationInterface/test/Core/Internals/batchsize.jl b/DifferentiationInterface/test/Core/Internals/batchsize.jl index 29c17eb24..05f39d2ff 100644 --- a/DifferentiationInterface/test/Core/Internals/batchsize.jl +++ b/DifferentiationInterface/test/Core/Internals/batchsize.jl @@ -25,11 +25,14 @@ BSS = BatchSizeSettings end @testset "SimpleFiniteDiff (adaptive)" begin + @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(0))) isa BSS{0,true,true} @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(2))) isa BSS{2,true,true} @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(6))) isa BSS{6,true,true} @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(12))) isa BSS{12,true,true} @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(24))) isa BSS{12,false,true} @test (pick_batchsize(AutoSimpleFiniteDiff(), zeros(100))) isa BSS{12,false,false} + @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(0)))) isa + BSS{0,true,true} @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(2)))) isa BSS{2,true,true} @test (@inferred pick_batchsize(AutoSimpleFiniteDiff(), @SVector(zeros(6)))) isa diff --git a/DifferentiationInterface/test/Core/ZeroBackends/test.jl b/DifferentiationInterface/test/Core/ZeroBackends/test.jl index 6d56dabd4..235812183 100644 --- a/DifferentiationInterface/test/Core/ZeroBackends/test.jl +++ b/DifferentiationInterface/test/Core/ZeroBackends/test.jl @@ -1,6 +1,7 @@ using DifferentiationInterface using DifferentiationInterface: AutoZeroForward, AutoZeroReverse using DifferentiationInterfaceTest +using LinearAlgebra using ComponentArrays: ComponentArrays using JLArrays: JLArrays using SparseMatrixColorings @@ -50,3 +51,15 @@ end logging=LOGGING, ) end + +@testset "Empty arrays" begin + make_empty(t) = typeof(t)[] + make_empty!(y, t) = nothing + @test gradient(sum, AutoZeroForward(), Float64[]) == Float64[] + @test derivative(make_empty, AutoZeroReverse(), 1.0) == Float64[] + @test derivative(make_empty!, Float64[], AutoZeroReverse(), 1.0) == Float64[] + @test_broken jacobian(copy, AutoZeroForward(), Float64[]) == I(0) + @test_broken jacobian(copy, AutoZeroReverse(), Float64[]) == I(0) + @test_broken jacobian(copyto!, Float64[], AutoZeroForward(), Float64[]) == I(0) + @test_broken jacobian(copyto!, Float64[], AutoZeroReverse(), Float64[]) == I(0) +end