Skip to content

Commit 74bb967

Browse files
authored
test!: better tests (#173)
* test!: better tests * Fix * Fix * Fix * Coverage
1 parent 474654d commit 74bb967

File tree

11 files changed

+317
-245
lines changed

11 files changed

+317
-245
lines changed

docs/src/api.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,9 @@ ImplicitFunction
1919
### Settings
2020

2121
```@docs
22-
IterativeLinearSolver
2322
MatrixRepresentation
2423
OperatorRepresentation
25-
NoPreparation
26-
ForwardPreparation
27-
ReversePreparation
28-
BothPreparation
24+
IterativeLinearSolver
2925
```
3026

3127
## Internals

src/ImplicitDifferentiation.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ include("preparation.jl")
3333
include("implicit_function.jl")
3434
include("execution.jl")
3535

36-
export IterativeLinearSolver
3736
export MatrixRepresentation, OperatorRepresentation
38-
export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation
37+
export IterativeLinearSolver
3938
export ImplicitFunction
4039

4140
end

src/implicit_function.jl

Lines changed: 10 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -22,27 +22,24 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂
2222
conditions;
2323
representation=OperatorRepresentation(),
2424
linear_solver=IterativeLinearSolver(),
25-
backend=nothing,
25+
backends=nothing,
2626
preparation=nothing,
2727
input_example=nothing,
2828
)
2929
3030
## Positional arguments
3131
3232
- `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractArray`, while `z` and `args` can be anything.
33-
- `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation
33+
- `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation.
3434
3535
## Keyword arguments
3636
37-
- `representation`: either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref)
38-
- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b)` (single solve) and one for `(A, B)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function.
39-
- `backend::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either
40-
- `nothing`, which means that the external autodiff system will be used
41-
- a single object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl)
42-
- a named tuple `(; x, y)` of objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl)
37+
- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref).
38+
- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b::AbstractVector)` (single solve) and one for `(A, B::AbstractMatrix)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function.
39+
- `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4340
- `preparation`: either `nothing` or a mode object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl): `ADTypes.ForwardMode()`, `ADTypes.ReverseMode()` or `ADTypes.ForwardOrReverseMode()`.
4441
- `input_example`: either `nothing` or a tuple `(x, args...)` used to prepare differentiation.
45-
- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types. Relaxing this to `strict=Val(false)` can prove necessary when working with custom array types like ComponentArrays.jl, which are not always compatible with iterative linear solvers.
42+
- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types.
4643
"""
4744
struct ImplicitFunction{
4845
F,
@@ -51,7 +48,6 @@ struct ImplicitFunction{
5148
R<:AbstractRepresentation,
5249
B<:Union{
5350
Nothing, #
54-
AbstractADType,
5551
NamedTuple{(:x, :y),<:Tuple{AbstractADType,AbstractADType}},
5652
},
5753
P<:Union{Nothing,AbstractMode},
@@ -90,59 +86,26 @@ function ImplicitFunction(
9086
prep_B = nothing
9187
prep_Bᵀ = nothing
9288
else
93-
real_backends = backends isa AbstractADType ? (; x=backends, y=backends) : backends
9489
x, args = first(input_example), Base.tail(input_example)
9590
y, z = solver(x, args...)
9691
c = conditions(x, y, z, args...)
9792
if preparation isa Union{ForwardMode,ForwardOrReverseMode}
9893
prep_A = prepare_A(
99-
representation,
100-
x,
101-
y,
102-
z,
103-
c,
104-
args...;
105-
conditions,
106-
backend=real_backends.y,
107-
strict,
94+
representation, x, y, z, c, args...; conditions, backend=backends.y, strict
10895
)
10996
prep_B = prepare_B(
110-
representation,
111-
x,
112-
y,
113-
z,
114-
c,
115-
args...;
116-
conditions,
117-
backend=real_backends.x,
118-
strict,
97+
representation, x, y, z, c, args...; conditions, backend=backends.x, strict
11998
)
12099
else
121100
prep_A = nothing
122101
prep_B = nothing
123102
end
124103
if preparation isa Union{ReverseMode,ForwardOrReverseMode}
125104
prep_Aᵀ = prepare_Aᵀ(
126-
representation,
127-
x,
128-
y,
129-
z,
130-
c,
131-
args...;
132-
conditions,
133-
backend=real_backends.y,
134-
strict,
105+
representation, x, y, z, c, args...; conditions, backend=backends.y, strict
135106
)
136107
prep_Bᵀ = prepare_Bᵀ(
137-
representation,
138-
x,
139-
y,
140-
z,
141-
c,
142-
args...;
143-
conditions,
144-
backend=real_backends.x,
145-
strict,
108+
representation, x, y, z, c, args...; conditions, backend=backends.x, strict
146109
)
147110
else
148111
prep_Aᵀ = nothing

src/settings.jl

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ Callable object that can solve linear systems `Ax = b` and `AX = B` in the same
77
88
# Constructor
99
10+
IterativeLinearSolver(; kwargs...)
1011
IterativeLinearSolver{package}(; kwargs...)
1112
1213
The type parameter `package` can be either:
1314
14-
- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl)
15+
- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default)
1516
- `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl)
1617
1718
Keyword arguments are passed on to the respective solver.
@@ -34,7 +35,15 @@ struct IterativeLinearSolver{package,K}
3435
end
3536
end
3637

37-
IterativeLinearSolver() = IterativeLinearSolver{:Krylov}()
38+
function Base.show(io::IO, linear_solver::IterativeLinearSolver{package}) where {package}
39+
print(io, "IterativeLinearSolver{$(repr(package))}(; ")
40+
for (k, v) in pairs(linear_solver.kwargs)
41+
print(io, "$k=$v, ")
42+
end
43+
return print(io, ")")
44+
end
45+
46+
IterativeLinearSolver(; kwargs...) = IterativeLinearSolver{:Krylov}(; kwargs...)
3847

3948
function (solver::IterativeLinearSolver{:Krylov})(A, b::AbstractVector)
4049
x, stats = Krylov.gmres(A, b; solver.kwargs...)
@@ -85,6 +94,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be
8594
8695
# Constructors
8796
97+
OperatorRepresentation(; symmetric=false, hermitian=false)
8898
OperatorRepresentation{package}(; symmetric=false, hermitian=false)
8999
90100
The type parameter `package` can be either:
@@ -108,38 +118,15 @@ struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresenta
108118
end
109119
end
110120

111-
OperatorRepresentation() = OperatorRepresentation{:LinearOperators}()
112-
113-
## Preparation
114-
115-
abstract type AbstractPreparation end
116-
117-
"""
118-
ForwardPreparation
119-
120-
Specify that the derivatives of the conditions should be prepared for subsequent forward-mode differentiation of the implicit function.
121-
"""
122-
struct ForwardPreparation <: AbstractPreparation end
123-
124-
"""
125-
ReversePreparation
126-
127-
Specify that the derivatives of the conditions should be prepared for subsequent reverse-mode differentiation of the implicit function.
128-
"""
129-
struct ReversePreparation <: AbstractPreparation end
130-
131-
"""
132-
BothPreparation
133-
134-
Specify that the derivatives of the conditions should be prepared for subsequent forward- or reverse-mode differentiation of the implicit function.
135-
"""
136-
struct BothPreparation <: AbstractPreparation end
137-
138-
"""
139-
NoPreparation
121+
function Base.show(
122+
io::IO, ::OperatorRepresentation{package,symmetric,hermitian}
123+
) where {package,symmetric,hermitian}
124+
return print(
125+
io,
126+
"OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian)",
127+
)
128+
end
140129

141-
Specify that the derivatives of the conditions should not be prepared for subsequent differentiation of the implicit function.
142-
"""
143-
struct NoPreparation <: AbstractPreparation end
130+
OperatorRepresentation(; kwargs...) = OperatorRepresentation{:LinearOperators}(; kwargs...)
144131

145132
function chainrules_suggested_backend end

test/examples.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
1-
@testitem "intro" begin
1+
@testitem "Intro" begin
22
include(joinpath(dirname(@__DIR__), "examples", "0_intro.jl"))
33
end
44

5-
@testitem "basic" begin
5+
@testitem "Basic" begin
66
include(joinpath(dirname(@__DIR__), "examples", "1_basic.jl"))
77
end
88

9-
@testitem "advanced" begin
9+
@testitem "Advanced" begin
1010
include(joinpath(dirname(@__DIR__), "examples", "2_advanced.jl"))
1111
end
1212

13-
@testitem "tricks" begin
13+
@testitem "Tricks" begin
1414
include(joinpath(dirname(@__DIR__), "examples", "3_tricks.jl"))
1515
end

test/formalities.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,19 @@ using TestItems
66
using Zygote: Zygote
77
Aqua.test_all(ImplicitDifferentiation; ambiguities=false, undocumented_names=true)
88
end
9+
910
@testitem "Formatting" begin
1011
using JuliaFormatter
1112
@test format(ImplicitDifferentiation; verbose=false, overwrite=false)
1213
end
14+
1315
@testitem "Static checking" begin
1416
using JET
1517
using ForwardDiff: ForwardDiff
1618
using Zygote: Zygote
1719
JET.test_package(ImplicitDifferentiation; target_defined_modules=true)
1820
end
21+
1922
@testitem "Imports" begin
2023
using ExplicitImports
2124
using ForwardDiff: ForwardDiff
@@ -27,6 +30,7 @@ end
2730
@test check_all_qualified_accesses_via_owners(ImplicitDifferentiation) === nothing
2831
@test check_no_self_qualified_accesses(ImplicitDifferentiation) === nothing
2932
end
33+
3034
@testitem "Doctests" begin
3135
using Documenter
3236
Documenter.DocMeta.setdocmeta!(

test/preparation.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
@testitem "Preparation" begin
2+
using ImplicitDifferentiation
3+
using ADTypes
4+
using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode
5+
using ForwardDiff: ForwardDiff
6+
using Zygote: Zygote
7+
using Test
8+
9+
solver(x) = sqrt.(x), nothing
10+
conditions(x, y, z) = y .^ 2 .- x
11+
x = rand(5)
12+
input_example = (x,)
13+
14+
@testset "None" begin
15+
implicit_none = ImplicitFunction(solver, conditions)
16+
@test implicit_none.prep_A === nothing
17+
@test implicit_none.prep_Aᵀ === nothing
18+
@test implicit_none.prep_B === nothing
19+
@test implicit_none.prep_Bᵀ === nothing
20+
end
21+
22+
@testset "ForwardMode" begin
23+
implicit_forward = ImplicitFunction(
24+
solver,
25+
conditions;
26+
preparation=ForwardMode(),
27+
backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()),
28+
input_example,
29+
)
30+
@test implicit_forward.prep_A !== nothing
31+
@test implicit_forward.prep_Aᵀ === nothing
32+
@test implicit_forward.prep_B !== nothing
33+
@test implicit_forward.prep_Bᵀ === nothing
34+
end
35+
36+
@testset "ReverseMode" begin
37+
implicit_reverse = ImplicitFunction(
38+
solver,
39+
conditions;
40+
preparation=ReverseMode(),
41+
backends=(; x=AutoZygote(), y=AutoZygote()),
42+
input_example,
43+
)
44+
@test implicit_reverse.prep_A === nothing
45+
@test implicit_reverse.prep_Aᵀ !== nothing
46+
@test implicit_reverse.prep_B === nothing
47+
@test implicit_reverse.prep_Bᵀ !== nothing
48+
end
49+
50+
@testset "Both" begin
51+
implicit_both = ImplicitFunction(
52+
solver,
53+
conditions;
54+
preparation=ForwardOrReverseMode(),
55+
backends=(; x=AutoForwardDiff(), y=AutoZygote()),
56+
input_example,
57+
)
58+
@test implicit_both.prep_A !== nothing
59+
@test implicit_both.prep_Aᵀ !== nothing
60+
@test implicit_both.prep_B !== nothing
61+
@test implicit_both.prep_Bᵀ !== nothing
62+
end
63+
end

test/printing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using TestItems
2+
3+
@testitem "Settings" begin
4+
@test startswith(string(OperatorRepresentation()), "Operator")
5+
@test startswith(string(IterativeLinearSolver(; atol=1e-5)), "Iterative")
6+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
using TestItemRunner
22

3+
@testmodule TestUtils begin
4+
include("utils.jl")
5+
export Scenario, test_implicit, add_arg_mult
6+
export default_solver, default_conditions
7+
end
8+
39
@run_package_tests

0 commit comments

Comments
 (0)