Skip to content

Commit 19803d1

Browse files
authored
feat: allow naming scenarios and test sets (#712)
* feat: allow naming scenarios and test sets * docstrings
1 parent 2896511 commit 19803d1

File tree

6 files changed

+86
-21
lines changed

6 files changed

+86
-21
lines changed

DifferentiationInterfaceTest/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.9.3"
4+
version = "0.9.4"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ function Base.zero(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
1515
res1=myzero(scen.res1),
1616
res2=myzero(scen.res2),
1717
smaller=isnothing(scen.smaller) ? nothing : zero(scen.smaller),
18+
name=isnothing(scen.name) ? nothing : scen.name * " [zero]",
1819
)
1920
end
2021

@@ -39,6 +40,7 @@ function change_function(
3940
else
4041
change_function(scen.smaller, new_f; keep_smaller=false)
4142
end,
43+
name=isnothing(scen.name) ? nothing : scen.name * " [new function]",
4244
)
4345
end
4446

@@ -63,6 +65,7 @@ function batchify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
6365
res1=new_res1,
6466
res2,
6567
smaller=isnothing(smaller) ? nothing : batchify(smaller),
68+
name=isnothing(scen.name) ? nothing : scen.name * " [batchified]",
6669
)
6770
elseif op == :hvp
6871
new_tang = (only(tang), -only(tang))
@@ -76,6 +79,7 @@ function batchify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
7679
res1,
7780
res2=new_res2,
7881
smaller=isnothing(smaller) ? nothing : batchify(smaller),
82+
name=isnothing(scen.name) ? nothing : scen.name * " [batchified]",
7983
)
8084
end
8185
end
@@ -160,6 +164,7 @@ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
160164
res1=mymultiply(scen.res1, a),
161165
res2=mymultiply(scen.res2, a),
162166
smaller=isnothing(scen.smaller) ? nothing : constantify(scen.smaller),
167+
name=isnothing(scen.name) ? nothing : scen.name * " [constantified]",
163168
)
164169
end
165170

@@ -213,6 +218,7 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
213218
res1=scen.res1,
214219
res2=scen.res2,
215220
smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller),
221+
name=isnothing(scen.name) ? nothing : scen.name * " [cachified]",
216222
)
217223
end
218224

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ This generic type should never be used directly: use the specific constructor co
1313
1414
# Constructors
1515
16-
Scenario{op,pl_op}(f, x; tang, contexts, res1, res2)
17-
Scenario{op,pl_op}(f!, y, x; tang, contexts, res1, res2)
16+
Scenario{op,pl_op}(f, x; tang, contexts, res1, res2, name)
17+
Scenario{op,pl_op}(f!, y, x; tang, contexts, res1, res2, name)
1818
1919
# Fields
2020
@@ -37,32 +37,57 @@ struct Scenario{op,pl_op,pl_fun,F,X,Y,T<:Union{Nothing,NTuple},C<:Tuple,R1,R2,S}
3737
res2::R2
3838
"private field (not part of the public API) containing a variant of the scenario to test preparation resizing"
3939
smaller::S
40+
"name of the scenario for display in test sets and dataframes"
41+
name::Union{String,Nothing}
4042
end
4143

4244
function Scenario{op,pl_op,pl_fun}(
43-
f::F; x::X, y::Y, tang::T, contexts::C, res1::R1, res2::R2, smaller::S=nothing
45+
f::F;
46+
x::X,
47+
y::Y,
48+
tang::T,
49+
contexts::C,
50+
res1::R1,
51+
res2::R2,
52+
smaller::S=nothing,
53+
name=nothing,
4454
) where {op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,S<:Union{Nothing,Scenario}}
4555
@assert smaller isa Union{Nothing,Scenario{op,pl_op,pl_fun,F,X,Y,T,C,R1,R2}}
4656
return Scenario{op,pl_op,pl_fun,F,X,Y,T,C,R1,R2,S}(
47-
f, x, y, tang, contexts, res1, res2, smaller
57+
f, x, y, tang, contexts, res1, res2, smaller, name
4858
)
4959
end
5060

5161
function Scenario{op,pl_op}(
52-
f, x; tang=nothing, contexts=(), res1=nothing, res2=nothing, smaller=nothing
62+
f,
63+
x;
64+
tang=nothing,
65+
contexts=(),
66+
res1=nothing,
67+
res2=nothing,
68+
smaller=nothing,
69+
name=nothing,
5370
) where {op,pl_op}
5471
@assert op in ALL_OPS
5572
@assert pl_op in (:in, :out)
5673
y = f(x, map(unwrap, contexts)...)
57-
return Scenario{op,pl_op,:out}(f; x, y, tang, contexts, res1, res2, smaller)
74+
return Scenario{op,pl_op,:out}(f; x, y, tang, contexts, res1, res2, smaller, name)
5875
end
5976

6077
function Scenario{op,pl_op}(
61-
f!, y, x; tang=nothing, contexts=(), res1=nothing, res2=nothing, smaller=nothing
78+
f!,
79+
y,
80+
x;
81+
tang=nothing,
82+
contexts=(),
83+
res1=nothing,
84+
res2=nothing,
85+
smaller=nothing,
86+
name=nothing,
6287
) where {op,pl_op}
6388
@assert op in ALL_OPS
6489
@assert pl_op in (:in, :out)
65-
return Scenario{op,pl_op,:in}(f!; x, y, tang, contexts, res1, res2, smaller)
90+
return Scenario{op,pl_op,:in}(f!; x, y, tang, contexts, res1, res2, smaller, name)
6691
end
6792

6893
Base.:(==)(scen1::Scenario, scen2::Scenario) = false
@@ -85,7 +110,8 @@ function Base.:(==)(
85110
)
86111
eq_res1 = scen1.res1 == scen2.res1
87112
eq_res2 = scen1.res2 == scen2.res2
88-
return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2)
113+
eq_name = scen1.name == scen2.name
114+
return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2 && eq_name)
89115
end
90116

91117
operator(::Scenario{op}) where {op} = op
@@ -123,12 +149,16 @@ end
123149
function Base.show(
124150
io::IO, scen::Scenario{op,pl_op,pl_fun,F,X,Y,T}
125151
) where {op,pl_op,pl_fun,F,X,Y,T}
126-
print(io, "Scenario{$(repr(op)),$(repr(pl_op))} $(string(scen.f)) : $X -> $Y")
127-
if op in (:pushforward, :pullback, :hvp)
128-
print(io, " ($(length(scen.tang)) tangents)")
129-
end
130-
if length(scen.contexts) > 0
131-
print(io, " ($(length(scen.contexts)) contexts)")
152+
if isnothing(scen.name)
153+
print(io, "Scenario{$(repr(op)),$(repr(pl_op))} $(string(scen.f)) : $X -> $Y")
154+
if op in (:pushforward, :pullback, :hvp)
155+
print(io, " ($(length(scen.tang)) tangents)")
156+
end
157+
if length(scen.contexts) > 0
158+
print(io, " ($(length(scen.contexts)) contexts)")
159+
end
160+
else
161+
print(io, scen.name)
132162
end
133163
return nothing
134164
end

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ This function always creates and runs a `@testset`, though its contents may vary
1717
1818
# Keyword arguments
1919
20+
- `testset_name=nothing`: how to display the test set
21+
2022
**Test categories:**
2123
2224
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario
@@ -66,6 +68,7 @@ Each setting tests/benchmarks a different subset of calls:
6668
function test_differentiation(
6769
backends::Vector{<:AbstractADType},
6870
scenarios::Vector{<:Scenario}=default_scenarios();
71+
testset_name::Union{String,Nothing}=nothing,
6972
# test categories
7073
correctness::Bool=true,
7174
type_stability::Symbol=:none,
@@ -105,11 +108,15 @@ function test_differentiation(
105108
scenarios = filter(s -> !(operator(s) in excluded), scenarios)
106109
scenarios = sort(scenarios; by=s -> (operator(s), string(s.f)))
107110

108-
title_additions =
109-
(correctness ? " + correctness" : "") *
110-
((type_stability != :none) ? " + type stability" : "") *
111-
((benchmark != :none) ? " + benchmarks" : "")
112-
title = "Testing" * title_additions[3:end]
111+
if isnothing(testset_name)
112+
title_additions =
113+
(correctness ? " + correctness" : "") *
114+
((type_stability != :none) ? " + type stability" : "") *
115+
((benchmark != :none) ? " + benchmarks" : "")
116+
title = "Testing" * title_additions[3:end]
117+
else
118+
title = testset_name
119+
end
113120

114121
benchmark_data = DifferentiationBenchmarkDataRow[]
115122

@@ -222,6 +229,7 @@ Specifying the set of scenarios is mandatory for this function.
222229
function benchmark_differentiation(
223230
backends,
224231
scenarios::Vector{<:Scenario};
232+
testset_name::Union{String,Nothing}=nothing,
225233
benchmark::Symbol=:prepared,
226234
excluded::Vector{Symbol}=Symbol[],
227235
logging::Bool=false,
@@ -235,6 +243,7 @@ function benchmark_differentiation(
235243
return test_differentiation(
236244
backends,
237245
scenarios;
246+
testset_name,
238247
correctness=false,
239248
type_stability=:none,
240249
allocations=:none,

DifferentiationInterfaceTest/test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ GROUP = get(ENV, "JULIA_DIT_TEST_GROUP", "All")
1414
@testset verbose = true "Formalities" begin
1515
include("formalities.jl")
1616
end
17+
@testset verbose = true "Scenarios" begin
18+
include("scenario.jl")
19+
end
1720
end
1821

1922
if GROUP == "Zero" || GROUP == "All"
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using DifferentiationInterface
2+
using DifferentiationInterfaceTest
3+
using ForwardDiff: ForwardDiff
4+
using Test
5+
6+
scen = Scenario{:gradient,:out}(
7+
sum, zeros(10); res1=ones(10), name="My pretty little scenario"
8+
)
9+
@test string(scen) == "My pretty little scenario"
10+
11+
testset = test_differentiation(
12+
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
13+
)
14+
15+
data = benchmark_differentiation(
16+
AutoForwardDiff(), [scen]; testset_name="My amazing test set"
17+
)

0 commit comments

Comments
 (0)