Skip to content

feat: compute Scenario results with a reference backend #839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DifferentiationInterfaceTest/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ DifferentiationInterfaceTest

```@docs
Scenario
compute_results
test_differentiation
benchmark_differentiation
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ include("tests/allocs_eval.jl")
include("test_differentiation.jl")

export FIRST_ORDER, SECOND_ORDER
export Scenario
export Scenario, compute_results
export test_differentiation, benchmark_differentiation
export DifferentiationBenchmarkDataRow

Expand Down
88 changes: 88 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/modify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,91 @@ closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples)
constantorcachify(scens::AbstractVector{<:Scenario}) = constantorcachify.(scens)

## Compute results with backend

get_res1(::Val, args...) = nothing
get_res2(::Val, args...) = nothing

function get_res1(::Val{:derivative}, f, backend::AbstractADType, x, contexts...)
return derivative(f, backend, x, contexts...)
end
function get_res1(::Val{:derivative}, f!, y, backend::AbstractADType, x, contexts...)
return derivative(f!, y, backend, x, contexts...)
end
function get_res1(::Val{:gradient}, f, backend::AbstractADType, x, contexts...)
return gradient(f, backend, x, contexts...)
end
function get_res1(::Val{:jacobian}, f, backend::AbstractADType, x, contexts...)
return jacobian(f, backend, x, contexts...)
end
function get_res1(::Val{:jacobian}, f!, y, backend::AbstractADType, x, contexts...)
return jacobian(f!, y, backend, x, contexts...)
end
function get_res1(::Val{:second_derivative}, f, backend::AbstractADType, x, contexts...)
return derivative(f, backend, x, contexts...)
end
function get_res1(::Val{:hessian}, f, backend::AbstractADType, x, contexts...)
return gradient(f, backend, x, contexts...)
end

function get_res2(::Val{:second_derivative}, f, backend::AbstractADType, x, contexts...)
return second_derivative(f, backend, x, contexts...)
end
function get_res2(::Val{:hessian}, f, backend::AbstractADType, x, contexts...)
return hessian(f, backend, x, contexts...)
end

function get_res1(::Val{:pushforward}, f, backend::AbstractADType, x, t, contexts...)
return pushforward(f, backend, x, t, contexts...)
end
function get_res1(::Val{:pushforward}, f!, y, backend::AbstractADType, x, t, contexts...)
return pushforward(f!, y, backend, x, t, contexts...)
end
function get_res1(::Val{:pullback}, f, backend::AbstractADType, x, t, contexts...)
return pullback(f, backend, x, t, contexts...)
end
function get_res1(::Val{:pullback}, f!, y, backend::AbstractADType, x, t, contexts...)
return pullback(f!, y, backend, x, t, contexts...)
end
function get_res1(::Val{:hvp}, f, backend::AbstractADType, x, t, contexts...)
return gradient(f, backend, x, contexts...)
end

function get_res2(::Val{:hvp}, f, backend::AbstractADType, x, t, contexts...)
return hvp(f, backend, x, t, contexts...)
end

"""
compute_results(scen::Scenario, backend::AbstractADType)

Return a scenario identical to `scen` but where the first- and second-order results `res1` and `res2` have been computed with the given differentiation `backend`.

Useful for comparison of outputs between backends.
"""
function compute_results(
scen::Scenario{op,pl_op,pl_fun}, backend::AbstractADType
) where {op,pl_op,pl_fun}
(; f, y, x, t, contexts, prep_args, name) = deepcopy(scen)
if pl_fun == :in
if isnothing(t)
new_res1 = get_res1(Val(op), f, y, backend, x, contexts...)
new_res2 = get_res2(Val(op), f, y, backend, x, contexts...)
else
new_res1 = get_res1(Val(op), f, y, backend, x, t, contexts...)
new_res2 = get_res2(Val(op), f, y, backend, x, t, contexts...)
end
else
if isnothing(t)
new_res1 = get_res1(Val(op), f, backend, x, contexts...)
new_res2 = get_res2(Val(op), f, backend, x, contexts...)
else
new_res1 = get_res1(Val(op), f, backend, x, t, contexts...)
new_res2 = get_res2(Val(op), f, backend, x, t, contexts...)
end
end
new_scen = Scenario{op,pl_op,pl_fun}(;
f, x, y, t, contexts, res1=new_res1, res2=new_res2, prep_args, name
)
return new_scen
end
37 changes: 27 additions & 10 deletions DifferentiationInterfaceTest/test/scenario.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
using DifferentiationInterface
using DifferentiationInterfaceTest
using DifferentiationInterfaceTest: default_scenarios
using ForwardDiff: ForwardDiff
using Test

scen = Scenario{:gradient,:out}(
sum, zeros(10); res1=ones(10), name="My pretty little scenario"
)
@test string(scen) == "My pretty little scenario"
@testset "Naming" begin
scen = Scenario{:gradient,:out}(
sum, zeros(10); res1=ones(10), name="My pretty little scenario"
)
@test string(scen) == "My pretty little scenario"

testset = test_differentiation(
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
)
testset = test_differentiation(
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
)

data = benchmark_differentiation(
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
)
data = benchmark_differentiation(
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
)
end;

@testset "Compute results" begin
scens = default_scenarios()
new_scens = map(s -> compute_results(s, AutoForwardDiff()), scens)

isapprox_robust(x, y) = isapprox(x, y)
isapprox_robust(x::Nothing, y::Nothing) = true
isapprox_robust(x::NTuple, y::NTuple) = all(map(isapprox, x, y))

for (sa, sb) in zip(scens, new_scens)
@test isapprox_robust(sa.res1, sb.res1)
@test isapprox_robust(sa.res2, sb.res2)
end
end;
Loading