Skip to content

Commit 4a59a3a

Browse files
authored
feat: compute Scenario results with a reference backend (#839)
1 parent 9d4e5dc commit 4a59a3a

File tree

6 files changed

+134
-13
lines changed

6 files changed

+134
-13
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...main)
8+
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.6...main)
9+
10+
## [0.7.6](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.5...DifferentiationInterface-v0.7.6)
11+
12+
### Fixed
13+
14+
- Put test deps into `test/Project.toml` ([#840](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/840))
15+
- Set up `pre-commit` ([#837](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/837))
916

1017
### Fixed
1118

DifferentiationInterfaceTest/CHANGELOG.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,18 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8-
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.10.0...main)
8+
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.10.1...main)
9+
10+
## [0.10.1](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.10.0...DifferentiationInterfaceTest-v0.10.1)
11+
12+
### Added
13+
14+
- Compute Scenario results with a reference backend ([#839](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/839))
915

1016
### Fixed
1117

18+
- Put test deps into `test/Project.toml` ([#840](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/840))
19+
- Set up `pre-commit` ([#837](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/837))
1220
- Bump compat for SparseConnectivityTracer v1 ([#823](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823))
1321

1422
## [0.10.0](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterfaceTest-v0.9.6...DifferentiationInterfaceTest-v0.10.0)

DifferentiationInterfaceTest/docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DifferentiationInterfaceTest
1313

1414
```@docs
1515
Scenario
16+
compute_results
1617
test_differentiation
1718
benchmark_differentiation
1819
```

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ include("tests/allocs_eval.jl")
135135
include("test_differentiation.jl")
136136

137137
export FIRST_ORDER, SECOND_ORDER
138-
export Scenario
138+
export Scenario, compute_results
139139
export test_differentiation, benchmark_differentiation
140140
export DifferentiationBenchmarkDataRow
141141

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,3 +366,91 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
366366
constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
367367
cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples)
368368
constantorcachify(scens::AbstractVector{<:Scenario}) = constantorcachify.(scens)
369+
370+
## Compute results with backend
371+
372+
get_res1(::Val, args...) = nothing
373+
get_res2(::Val, args...) = nothing
374+
375+
function get_res1(::Val{:derivative}, f, backend::AbstractADType, x, contexts...)
376+
return derivative(f, backend, x, contexts...)
377+
end
378+
function get_res1(::Val{:derivative}, f!, y, backend::AbstractADType, x, contexts...)
379+
return derivative(f!, y, backend, x, contexts...)
380+
end
381+
function get_res1(::Val{:gradient}, f, backend::AbstractADType, x, contexts...)
382+
return gradient(f, backend, x, contexts...)
383+
end
384+
function get_res1(::Val{:jacobian}, f, backend::AbstractADType, x, contexts...)
385+
return jacobian(f, backend, x, contexts...)
386+
end
387+
function get_res1(::Val{:jacobian}, f!, y, backend::AbstractADType, x, contexts...)
388+
return jacobian(f!, y, backend, x, contexts...)
389+
end
390+
function get_res1(::Val{:second_derivative}, f, backend::AbstractADType, x, contexts...)
391+
return derivative(f, backend, x, contexts...)
392+
end
393+
function get_res1(::Val{:hessian}, f, backend::AbstractADType, x, contexts...)
394+
return gradient(f, backend, x, contexts...)
395+
end
396+
397+
function get_res2(::Val{:second_derivative}, f, backend::AbstractADType, x, contexts...)
398+
return second_derivative(f, backend, x, contexts...)
399+
end
400+
function get_res2(::Val{:hessian}, f, backend::AbstractADType, x, contexts...)
401+
return hessian(f, backend, x, contexts...)
402+
end
403+
404+
function get_res1(::Val{:pushforward}, f, backend::AbstractADType, x, t, contexts...)
405+
return pushforward(f, backend, x, t, contexts...)
406+
end
407+
function get_res1(::Val{:pushforward}, f!, y, backend::AbstractADType, x, t, contexts...)
408+
return pushforward(f!, y, backend, x, t, contexts...)
409+
end
410+
function get_res1(::Val{:pullback}, f, backend::AbstractADType, x, t, contexts...)
411+
return pullback(f, backend, x, t, contexts...)
412+
end
413+
function get_res1(::Val{:pullback}, f!, y, backend::AbstractADType, x, t, contexts...)
414+
return pullback(f!, y, backend, x, t, contexts...)
415+
end
416+
function get_res1(::Val{:hvp}, f, backend::AbstractADType, x, t, contexts...)
417+
return gradient(f, backend, x, contexts...)
418+
end
419+
420+
function get_res2(::Val{:hvp}, f, backend::AbstractADType, x, t, contexts...)
421+
return hvp(f, backend, x, t, contexts...)
422+
end
423+
424+
"""
425+
compute_results(scen::Scenario, backend::AbstractADType)
426+
427+
Return a scenario identical to `scen` but where the first- and second-order results `res1` and `res2` have been computed with the given differentiation `backend`.
428+
429+
Useful for comparison of outputs between backends.
430+
"""
431+
function compute_results(
432+
scen::Scenario{op,pl_op,pl_fun}, backend::AbstractADType
433+
) where {op,pl_op,pl_fun}
434+
(; f, y, x, t, contexts, prep_args, name) = deepcopy(scen)
435+
if pl_fun == :in
436+
if isnothing(t)
437+
new_res1 = get_res1(Val(op), f, y, backend, x, contexts...)
438+
new_res2 = get_res2(Val(op), f, y, backend, x, contexts...)
439+
else
440+
new_res1 = get_res1(Val(op), f, y, backend, x, t, contexts...)
441+
new_res2 = get_res2(Val(op), f, y, backend, x, t, contexts...)
442+
end
443+
else
444+
if isnothing(t)
445+
new_res1 = get_res1(Val(op), f, backend, x, contexts...)
446+
new_res2 = get_res2(Val(op), f, backend, x, contexts...)
447+
else
448+
new_res1 = get_res1(Val(op), f, backend, x, t, contexts...)
449+
new_res2 = get_res2(Val(op), f, backend, x, t, contexts...)
450+
end
451+
end
452+
new_scen = Scenario{op,pl_op,pl_fun}(;
453+
f, x, y, t, contexts, res1=new_res1, res2=new_res2, prep_args, name
454+
)
455+
return new_scen
456+
end
Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,34 @@
11
using DifferentiationInterface
22
using DifferentiationInterfaceTest
3+
using DifferentiationInterfaceTest: default_scenarios
34
using ForwardDiff: ForwardDiff
45
using Test
56

6-
scen = Scenario{:gradient,:out}(
7-
sum, zeros(10); res1=ones(10), name="My pretty little scenario"
8-
)
9-
@test string(scen) == "My pretty little scenario"
7+
@testset "Naming" begin
8+
scen = Scenario{:gradient,:out}(
9+
sum, zeros(10); res1=ones(10), name="My pretty little scenario"
10+
)
11+
@test string(scen) == "My pretty little scenario"
1012

11-
testset = test_differentiation(
12-
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
13-
)
13+
testset = test_differentiation(
14+
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
15+
)
1416

15-
data = benchmark_differentiation(
16-
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
17-
)
17+
data = benchmark_differentiation(
18+
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
19+
)
20+
end;
21+
22+
@testset "Compute results" begin
23+
scens = default_scenarios()
24+
new_scens = map(s -> compute_results(s, AutoForwardDiff()), scens)
25+
26+
isapprox_robust(x, y) = isapprox(x, y)
27+
isapprox_robust(x::Nothing, y::Nothing) = true
28+
isapprox_robust(x::NTuple, y::NTuple) = all(map(isapprox, x, y))
29+
30+
for (sa, sb) in zip(scens, new_scens)
31+
@test isapprox_robust(sa.res1, sb.res1)
32+
@test isapprox_robust(sa.res2, sb.res2)
33+
end
34+
end;

0 commit comments

Comments
 (0)