|
| 1 | +module ChainRulesTestUtils |
| 2 | + |
| 3 | +using ChainRulesCore |
| 4 | +using ChainRulesCore: frule, rrule |
| 5 | +using ChainRulesCore: AbstractDifferential |
| 6 | +using FiniteDifferences |
| 7 | +using Test |
| 8 | + |
| 9 | +const _fdm = central_fdm(5, 1) |
| 10 | + |
| 11 | +export test_scalar, frule_test, rrule_test, isapprox, generate_well_conditioned_matrix |
| 12 | + |
| 13 | +Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...) = error("Tried to differentiate w.r.t. a `DoesNotExist`") |
| 14 | +Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...) = isapprox(extern(d_ad), d_fd; kwargs...) |
| 15 | + |
| 16 | +function _make_fdm_call(fdm, f, ȳ, xs, ignores) |
| 17 | + sig = Expr(:tuple) |
| 18 | + call = Expr(:call, f) |
| 19 | + newxs = Any[] |
| 20 | + arginds = Int[] |
| 21 | + i = 1 |
| 22 | + for (x, ignore) in zip(xs, ignores) |
| 23 | + if ignore |
| 24 | + push!(call.args, x) |
| 25 | + else |
| 26 | + push!(call.args, Symbol(:x, i)) |
| 27 | + push!(sig.args, Symbol(:x, i)) |
| 28 | + push!(newxs, x) |
| 29 | + push!(arginds, i) |
| 30 | + end |
| 31 | + i += 1 |
| 32 | + end |
| 33 | + fdexpr = :(j′vp($fdm, $sig -> $call, $ȳ, $(newxs...))) |
| 34 | + fd = eval(fdexpr) |
| 35 | + fd isa Tuple || (fd = (fd,)) |
| 36 | + args = Any[nothing for _ in 1:length(xs)] |
| 37 | + for (dx, ind) in zip(fd, arginds) |
| 38 | + args[ind] = dx |
| 39 | + end |
| 40 | + return (args...,) |
| 41 | +end |
| 42 | + |
| 43 | +# Useful for LinearAlgebra tests |
| 44 | +function generate_well_conditioned_matrix(rng, N) |
| 45 | + A = randn(rng, N, N) |
| 46 | + return A * A' + I |
| 47 | +end |
| 48 | + |
| 49 | +""" |
| 50 | + test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) |
| 51 | +
|
| 52 | +Given a function `f` with scalar input and scalar output, perform finite differencing checks, |
| 53 | +at input point `x` to confirm that there are correct `frule` and `rrule`s provided. |
| 54 | +
|
| 55 | +# Arguments |
| 56 | +- `f`: Function for which the `frule` and `rrule` should be tested. |
| 57 | +- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). |
| 58 | +
|
| 59 | +All keyword arguments except for `fdm` is passed to `isapprox`. |
| 60 | +""" |
| 61 | +function test_scalar(f, x; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) |
| 62 | + ensure_not_running_on_functor(f, "test_scalar") |
| 63 | + |
| 64 | + r_res = rrule(f, x) |
| 65 | + f_res = frule(f, x, Zero(), 1) |
| 66 | + @test r_res !== nothing # Check the rule was defined |
| 67 | + @test f_res !== nothing |
| 68 | + r_fx, prop_rule = r_res |
| 69 | + f_fx, f_∂x = f_res |
| 70 | + @testset "$f at $x, $(nameof(rule))" for (rule, fx, ∂x) in ( |
| 71 | + (rrule, r_fx, prop_rule(1)), |
| 72 | + (frule, f_fx, f_∂x) |
| 73 | + ) |
| 74 | + @test fx == f(x) # Check we still get the normal value, right |
| 75 | + |
| 76 | + if rule == rrule |
| 77 | + ∂self, ∂x = ∂x |
| 78 | + @test ∂self === NO_FIELDS |
| 79 | + end |
| 80 | + @test isapprox(∂x, fdm(f, x); rtol=rtol, atol=atol, kwargs...) |
| 81 | + end |
| 82 | +end |
| 83 | + |
| 84 | +function ensure_not_running_on_functor(f, name) |
| 85 | + # if x itself is a Type, then it is a constructor, thus not a functor. |
| 86 | + # This also catchs UnionAll constructors which have a `:var` and `:body` fields |
| 87 | + f isa Type && return |
| 88 | + |
| 89 | + if fieldcount(typeof(f)) > 0 |
| 90 | + throw(ArgumentError( |
| 91 | + "$name cannot be used on closures/functors (such as $f)" |
| 92 | + )) |
| 93 | + end |
| 94 | +end |
| 95 | + |
| 96 | +""" |
| 97 | + frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) |
| 98 | +
|
| 99 | +# Arguments |
| 100 | +- `f`: Function for which the `frule` should be tested. |
| 101 | +- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). |
| 102 | +- `ẋ`: differential w.r.t. `x` (should generally be set randomly). |
| 103 | +
|
| 104 | +All keyword arguments except for `fdm` are passed to `isapprox`. |
| 105 | +""" |
| 106 | +function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) |
| 107 | + return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...) |
| 108 | +end |
| 109 | + |
| 110 | +function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) |
| 111 | + ensure_not_running_on_functor(f, "frule_test") |
| 112 | + xs, ẋs = collect(zip(xẋs...)) |
| 113 | + Ω, dΩ_ad = frule(f, xs..., NO_FIELDS, ẋs...) |
| 114 | + @test f(xs...) == Ω |
| 115 | + |
| 116 | + # Correctness testing via finite differencing. |
| 117 | + dΩ_fd = jvp(fdm, xs->f(xs...), (xs, ẋs)) |
| 118 | + @test isapprox( |
| 119 | + collect(extern.(dΩ_ad)), # Use collect so can use vector equality |
| 120 | + collect(dΩ_fd); |
| 121 | + rtol=rtol, |
| 122 | + atol=atol, |
| 123 | + kwargs... |
| 124 | + ) |
| 125 | +end |
| 126 | + |
| 127 | + |
| 128 | +""" |
| 129 | + rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...) |
| 130 | +
|
| 131 | +# Arguments |
| 132 | +- `f`: Function to which rule should be applied. |
| 133 | +- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly). |
| 134 | + Should be same structure as `f(x)` (so if multiple returns should be a tuple) |
| 135 | +- `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain). |
| 136 | +- `x̄`: currently accumulated adjoint (should generally be set randomly). |
| 137 | +
|
| 138 | +All keyword arguments except for `fdm` are passed to `isapprox`. |
| 139 | +""" |
| 140 | +function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) |
| 141 | + ensure_not_running_on_functor(f, "rrule_test") |
| 142 | + |
| 143 | + # Check correctness of evaluation. |
| 144 | + fx, pullback = rrule(f, x) |
| 145 | + @test collect(fx) ≈ collect(f(x)) # use collect so can do vector equality |
| 146 | + (∂self, x̄_ad) = if fx isa Tuple |
| 147 | + # If the function returned multiple values, |
| 148 | + # then it must have multiple seeds for propagating backwards |
| 149 | + pullback(ȳ...) |
| 150 | + else |
| 151 | + pullback(ȳ) |
| 152 | + end |
| 153 | + |
| 154 | + @test ∂self === NO_FIELDS # No internal fields |
| 155 | + # Correctness testing via finite differencing. |
| 156 | + x̄_fd = j′vp(fdm, f, ȳ, x) |
| 157 | + @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) |
| 158 | +end |
| 159 | + |
| 160 | +# case where `f` takes multiple arguments |
| 161 | +function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...) |
| 162 | + ensure_not_running_on_functor(f, "rrule_test") |
| 163 | + |
| 164 | + # Check correctness of evaluation. |
| 165 | + xs, x̄s = collect(zip(xx̄s...)) |
| 166 | + y, pullback = rrule(f, xs...) |
| 167 | + @test f(xs...) == y |
| 168 | + |
| 169 | + @assert !(isa(ȳ, Thunk)) |
| 170 | + ∂s = pullback(ȳ) |
| 171 | + ∂self = ∂s[1] |
| 172 | + x̄s_ad = ∂s[2:end] |
| 173 | + @test ∂self === NO_FIELDS |
| 174 | + |
| 175 | + # Correctness testing via finite differencing. |
| 176 | + x̄s_fd = j′vp(fdm, f, ȳ, xs...) |
| 177 | + map(x̄s_ad, x̄s_fd) do x̄_ad, x̄_fd |
| 178 | + @test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...) |
| 179 | + end |
| 180 | +end |
| 181 | + |
| 182 | +end # module |
0 commit comments