Skip to content

Commit 5dfd7ad

Browse files
authored
test: allow passing other scenarios to adjust preparation from (#701)
* test: allow passing other scenarios to adjust preparation from * Typo * copy * Switch to other scenario bundled with initial scenario * No fail fast * Typo * better batchsize handling * Fix * fiox * fix * fix scenarios * Smaller scens * Fix tests * Fix tracker * Use with_contexts for Tracker * fixes * Revert Tracker changes
1 parent 9773294 commit 5dfd7ad

File tree

25 files changed

+617
-253
lines changed

25 files changed

+617
-253
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: true # TODO: toggle
28+
fail-fast: false # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.34"
4+
version = "0.6.35"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
2-
function DI.BatchSizeSettings(::AutoEnzyme, N::Integer)
2+
function DI.pick_batchsize(::AutoEnzyme, N::Integer)
33
B = DI.reasonable_batchsize(N, 16)
44
return DI.BatchSizeSettings{B}(N)
55
end

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,23 @@ function DI.prepare_derivative(
199199
return ForwardDiffTwoArgDerivativePrep(config)
200200
end
201201

202+
function DI.prepare!_derivative(
203+
f!::F,
204+
y,
205+
old_prep::ForwardDiffTwoArgDerivativePrep,
206+
backend::AutoForwardDiff,
207+
x,
208+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
209+
) where {F,C}
210+
if y isa Vector
211+
(; config) = old_prep
212+
resize!(config.duals, length(y))
213+
return old_prep
214+
else
215+
return DI.prepare_derivative(f!, y, backend, x, contexts...)
216+
end
217+
end
218+
202219
function DI.value_and_derivative(
203220
f!::F,
204221
y,
@@ -352,6 +369,25 @@ function DI.prepare_jacobian(
352369
return ForwardDiffTwoArgJacobianPrep(config)
353370
end
354371

372+
function DI.prepare!_jacobian(
373+
f!::F,
374+
y,
375+
old_prep::ForwardDiffTwoArgJacobianPrep,
376+
backend::AutoForwardDiff,
377+
x,
378+
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
379+
) where {F,C}
380+
if x isa Vector && y isa Vector
381+
(; config) = old_prep
382+
(yduals, xduals) = config.duals
383+
resize!(yduals, length(y))
384+
resize!(xduals, length(x))
385+
return old_prep
386+
else
387+
return DI.prepare_jacobian(f!, y, backend, x, contexts...)
388+
end
389+
end
390+
355391
function DI.value_and_jacobian(
356392
f!::F,
357393
y,

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, N::Integer)
1+
function DI.pick_batchsize(::AutoForwardDiff{nothing}, N::Integer)
22
chunksize = ForwardDiff.pickchunksize(N)
33
return DI.BatchSizeSettings{chunksize}(N)
44
end
55

6-
function DI.BatchSizeSettings(::AutoForwardDiff{chunksize}, N::Integer) where {chunksize}
6+
function DI.pick_batchsize(::AutoForwardDiff{chunksize}, N::Integer) where {chunksize}
77
return DI.BatchSizeSettings{chunksize}(N)
88
end
99

DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/DifferentiationInterfacePolyesterForwardDiffExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ end
1313

1414
DI.check_available(::AutoPolyesterForwardDiff) = true
1515

16-
function DI.BatchSizeSettings(backend::AutoPolyesterForwardDiff, x::AbstractArray)
17-
return DI.BatchSizeSettings(single_threaded(backend), x)
16+
function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, x::AbstractArray)
17+
return DI.pick_batchsize(single_threaded(backend), x)
1818
end
1919

20-
function DI.BatchSizeSettings(backend::AutoPolyesterForwardDiff, N::Integer)
21-
return DI.BatchSizeSettings(single_threaded(backend), N)
20+
function DI.pick_batchsize(backend::AutoPolyesterForwardDiff, N::Integer)
21+
return DI.pick_batchsize(single_threaded(backend), N)
2222
end
2323

2424
function DI.threshold_batchsize(

DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,25 @@ end
1414

1515
DI.ismutable_array(::Type{<:SArray}) = false
1616

17-
function DI.BatchSizeSettings(::DI.AutoSimpleFiniteDiff{nothing}, x::StaticArray)
17+
function DI.pick_batchsize(::DI.AutoSimpleFiniteDiff{nothing}, x::StaticArray)
1818
return DI.BatchSizeSettings{length(x),true,true}(length(x))
1919
end
2020

21-
function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, x::StaticArray)
21+
function DI.pick_batchsize(::AutoForwardDiff{nothing}, x::StaticArray)
2222
return DI.BatchSizeSettings{length(x),true,true}(length(x))
2323
end
2424

25-
function DI.BatchSizeSettings(::AutoEnzyme, x::StaticArray)
25+
function DI.pick_batchsize(::AutoEnzyme, x::StaticArray)
2626
return DI.BatchSizeSettings{length(x),true,true}(length(x))
2727
end
2828

29-
function DI.BatchSizeSettings(
29+
function DI.pick_batchsize(
3030
::DI.AutoSimpleFiniteDiff{chunksize}, x::StaticArray
3131
) where {chunksize}
3232
return DI.BatchSizeSettings{chunksize}(Val(length(x)))
3333
end
3434

35-
function DI.BatchSizeSettings(
36-
::AutoForwardDiff{chunksize}, x::StaticArray
37-
) where {chunksize}
35+
function DI.pick_batchsize(::AutoForwardDiff{chunksize}, x::StaticArray) where {chunksize}
3836
return DI.BatchSizeSettings{chunksize}(Val(length(x)))
3937
end
4038

DifferentiationInterface/src/fallbacks/change_prep.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,22 @@ for op in [
4343
if op in (:derivative, :gradient, :jacobian)
4444
# 1-arg
4545
@eval function $prep_op!(
46-
f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
46+
f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
4747
) where {F,C}
4848
return $prep_op(f, backend, x, contexts...)
4949
end
5050
op == :gradient && continue
5151
# 2-arg
5252
@eval function $prep_op!(
53-
f!::F, y, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
53+
f!::F, y, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
5454
) where {F,C}
5555
return $prep_op(f!, y, backend, x, contexts...)
5656
end
5757

5858
elseif op in (:second_derivative, :hessian)
5959
# 1-arg
6060
@eval function $prep_op!(
61-
f::F, ::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
61+
f::F, old_prep::$P, backend::AbstractADType, x, contexts::Vararg{Context,C}
6262
) where {F,C}
6363
return $prep_op(f, backend, x, contexts...)
6464
end
@@ -67,7 +67,7 @@ for op in [
6767
# 1-arg
6868
@eval function $prep_op!(
6969
f::F,
70-
::$P,
70+
old_prep::$P,
7171
backend::AbstractADType,
7272
x,
7373
seed::NTuple,
@@ -96,7 +96,7 @@ for op in [
9696
@eval function $prep_op!(
9797
f!::F,
9898
y,
99-
::$P,
99+
old_prep::$P,
100100
backend::AbstractADType,
101101
x,
102102
seed::NTuple,

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ function prepare_jacobian(
100100
perf = pushforward_performance(backend)
101101
# type-unstable
102102
if perf isa PushforwardFast
103-
batch_size_settings = BatchSizeSettings(backend, x)
103+
batch_size_settings = pick_batchsize(backend, x)
104104
else
105-
batch_size_settings = BatchSizeSettings(backend, y)
105+
batch_size_settings = pick_batchsize(backend, y)
106106
end
107107
# function barrier
108108
return _prepare_jacobian_aux(
@@ -116,9 +116,9 @@ function prepare_jacobian(
116116
perf = pushforward_performance(backend)
117117
# type-unstable
118118
if perf isa PushforwardFast
119-
batch_size_settings = BatchSizeSettings(backend, x)
119+
batch_size_settings = pick_batchsize(backend, x)
120120
else
121-
batch_size_settings = BatchSizeSettings(backend, y)
121+
batch_size_settings = pick_batchsize(backend, y)
122122
end
123123
# function barrier
124124
return _prepare_jacobian_aux(

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ abstract type FromPrimitive <: AbstractADType end
33
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
44
inplace_support(fromprim::FromPrimitive) = inplace_support(fromprim.backend)
55

6-
function pick_batchsize(fromprim::FromPrimitive, x_or_N::Union{AbstractArray,Integer})
7-
return pick_batchsize(fromprim.backend, x_or_N)
6+
function pick_batchsize(fromprim::FromPrimitive, N::Integer)
7+
return pick_batchsize(fromprim.backend, N)
88
end
99

1010
struct AutoReverseFromPrimitive{B} <: FromPrimitive

0 commit comments

Comments
 (0)