Skip to content

Commit d2e12ae

Browse files
authored
feat!: specify preparation arguments in DIT Scenario (#786)
* feat!: specify preparation arguments in DIT `Scenario` * Fix * Fixes * Fixes * Fixes * Fix static arrays * Fix * Fix sparse and complex * All works except HVP * Fix tangents for prep same point * Fixes * Update DifferentiationInterfaceTest/src/scenarios/scenario.jl
1 parent c2bd64f commit d2e12ae

File tree

27 files changed

+872
-839
lines changed

27 files changed

+872
-839
lines changed

DifferentiationInterface/src/utils/prep.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function check_prep(
198198
if SIG != EXEC_SIG
199199
throw(
200200
PreparationMismatchError(
201-
SIG, EXEC_SIG; format=[:f, :backend, :x, :tang, :contexts]
201+
SIG, EXEC_SIG; format=[:f, :backend, :x, :t, :contexts]
202202
),
203203
)
204204
end
@@ -213,7 +213,7 @@ function check_prep(
213213
if SIG != EXEC_SIG
214214
throw(
215215
PreparationMismatchError(
216-
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :tang, :contexts]
216+
SIG, EXEC_SIG; format=[:f!, :y, :backend, :x, :t, :contexts]
217217
),
218218
)
219219
end

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ function differentiatewith_scenarios()
1616
DIT.function_place(scen) == :out
1717
end
1818
good_scens = map(bad_scens) do scen
19-
DIT.change_function(
20-
scen, DifferentiateWith(scen.f, AutoFiniteDiff()); keep_smaller=false
21-
)
19+
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
2220
end
2321
return good_scens
2422
end

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ end
2727
include_cachified=true,
2828
include_constantorcachified=true,
2929
use_tuples=true,
30+
include_smaller=true,
3031
);
3132
excluded=[:second_derivative, :hvp],
3233
logging=LOGGING,

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ end
4141
include_cachified=true,
4242
include_constantorcachified=true,
4343
use_tuples=true,
44+
include_smaller=true,
4445
);
4546
logging=LOGGING,
4647
)

DifferentiationInterface/test/Core/Internals/signature.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ end
9898
- exec: Nothing
9999
- backend: ✅
100100
- x: ✅
101-
- tang: ✅
101+
- t: ✅
102102
- contexts: ✅
103103
""" pushforward(nothing, prep, backend, x, (x,), Constant(c))
104104
end
@@ -119,7 +119,7 @@ end
119119
- y: ✅
120120
- backend: ✅
121121
- x: ✅
122-
- tang: ✅
122+
- t: ✅
123123
- contexts: ✅
124124
""" pushforward(nothing, y, prep, backend, x, (x,))
125125
end

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
@testset "Dense" begin
6464
test_differentiation(
6565
vcat(backends, second_order_backends),
66-
default_scenarios(; include_constantified=true);
66+
default_scenarios(; include_constantified=true, include_smaller=true);
6767
logging=LOGGING,
6868
)
6969

DifferentiationInterfaceTest/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Changed
11+
12+
- Specify preparation arguments in DIT Scenario ([#786])
13+
1014
## [0.9.6] - 2025-03-28
1115

1216
### Added
@@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1822
[unreleased]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...main
1923
[0.9.6]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.5...DifferentiationInterfaceTest-v0.9.6
2024

25+
[#786]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/786
2126
[#749]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/749
2227
[#748]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/748
2328
[#745]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/745

DifferentiationInterfaceTest/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.9.6"
4+
version = "0.10.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -44,7 +44,7 @@ AllocCheck = "0.2"
4444
Chairmarks = "1.2.1"
4545
ComponentArrays = "0.15"
4646
DataFrames = "1.6.1"
47-
DifferentiationInterface = "0.6.0"
47+
DifferentiationInterface = "0.6.53"
4848
DocStringExtensions = "0.8,0.9"
4949
ExplicitImports = "1.10.1"
5050
FiniteDiff = "2.27.0"

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,13 @@ function comp_to_num_scenarios_onearg(x::ComponentVector; dx::AbstractVector, dy
3333
append!(
3434
scens,
3535
[
36-
DIT.Scenario{:pullback,pl_op}(f, x; tang=(dy,), res1=(dx_from_dy,)),
36+
DIT.Scenario{:pullback,pl_op}(f, x, (dy,); res1=(dx_from_dy,)),
3737
DIT.Scenario{:gradient,pl_op}(f, x; res1=grad),
3838
],
3939
)
4040
end
4141
for pl_op in (:out,)
42-
append!(
43-
scens, [DIT.Scenario{:pushforward,pl_op}(f, x; tang=(dx,), res1=(dy_from_dx,))]
44-
)
42+
append!(scens, [DIT.Scenario{:pushforward,pl_op}(f, x, (dx,); res1=(dy_from_dx,))])
4543
end
4644
return scens
4745
end

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
9292
g = gradient_finite_differences(square_loss, model, x)
9393

9494
scen = DIT.Scenario{:gradient,:out}(
95-
square_loss, model; contexts=(DI.Constant(x),), res1=g
95+
square_loss,
96+
model,
97+
DI.Constant(x);
98+
prep_args=(x=model, contexts=(DI.Constant(x),)),
99+
res1=g,
96100
)
97101
push!(scens, scen)
98102

@@ -163,7 +167,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
163167
Flux.trainmode!(model)
164168
g = gradient_finite_differences(square_loss, model, x)
165169
scen = DIT.Scenario{:gradient,:out}(
166-
square_loss, model; contexts=(DI.Constant(x),), res1=g
170+
square_loss,
171+
model,
172+
DI.Constant(x);
173+
prep_args=(; x=model, contexts=(DI.Constant(x),)),
174+
res1=g,
167175
)
168176
push!(scens, scen)
169177
end
@@ -191,7 +199,11 @@ function DIT.flux_scenarios(rng::AbstractRNG=default_rng())
191199
Flux.trainmode!(model)
192200
g = gradient_finite_differences(square_loss_iterated, model, x)
193201
scen = DIT.Scenario{:gradient,:out}(
194-
square_loss_iterated, model; contexts=(DI.Constant(x),), res1=g
202+
square_loss_iterated,
203+
model,
204+
DI.Constant(x);
205+
prep_args=(; x=model, contexts=(DI.Constant(x),)),
206+
res1=g,
195207
)
196208
push!(scens, scen)
197209
end

0 commit comments

Comments
 (0)