|
75 | 75 | - `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
|
76 | 76 | - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
|
77 | 77 | - `ẋ`: differential w.r.t. `x`, will be generated automatically if not provided
|
78 |
| - Non-differentiable arguments, such as indices, should have `ẋ` set as `nothing`. |
| 78 | + Non-differentiable arguments, such as indices, should have `ẋ` set as `DoesNotExist()`. |
79 | 79 |
|
80 | 80 | # Keyword Arguments
|
81 | 81 | - `output_tangent` tangent to test accumulation of derivatives against
|
@@ -114,7 +114,16 @@ function test_frule(
|
114 | 114 | Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
|
115 | 115 | check_equal(Ω_ad, Ω; isapprox_kwargs...)
|
116 | 116 |
|
117 |
| - ẋs_is_ignored = ẋs .== nothing |
| 117 | + # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 |
| 118 | + ẋs_is_ignored = isa.(ẋs, Union{Nothing, DoesNotExist}) |
| 119 | + if any(ẋs .== nothing) |
| 120 | + Base.depwarn( |
| 121 | + "test_frule(f, k ⊢ nothing) is deprecated, use " * |
| 122 | + "test_frule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks", |
| 123 | + :test_frule |
| 124 | + ) |
| 125 | + end |
| 126 | + |
118 | 127 | # Correctness testing via finite differencing.
|
119 | 128 | dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
|
120 | 129 | check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)
|
|
134 | 143 | - `inputs` either the primal inputs `x`, or primals and their tangents: `x ⊢ ẋ`
|
135 | 144 | - `x`: input at which to evaluate `f` (should generally be set to an arbitary point in the domain).
|
136 | 145 | - `x̄`: currently accumulated cotangent, will be generated automatically if not provided
|
137 |
| - Non-differentiable arguments, such as indices, should have `x̄` set as `nothing`. |
| 146 | + Non-differentiable arguments, such as indices, should have `x̄` set as `DoesNotExist()`. |
138 | 147 |
|
139 | 148 | # Keyword Arguments
|
140 | 149 | - `output_tangent` the seed to propagate backward for testing (techncally a cotangent).
|
@@ -182,10 +191,19 @@ function test_rrule(
|
182 | 191 | @test ∂self === NO_FIELDS # No internal fields
|
183 | 192 |
|
184 | 193 | # Correctness testing via finite differencing.
|
185 |
| - x̄s_is_dne = accumulated_x̄ .== nothing |
| 194 | + # TODO: remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113 |
| 195 | + x̄s_is_dne = isa.(accumulated_x̄, Union{Nothing, DoesNotExist}) |
| 196 | + if any(accumulated_x̄ .== nothing) |
| 197 | + Base.depwarn( |
| 198 | + "test_rrule(f, k ⊢ nothing) is deprecated, use " * |
| 199 | + "test_rrule(f, k ⊢ DoesNotExist()) instead for non-differentiable ks", |
| 200 | + :test_rrule |
| 201 | + ) |
| 202 | + end |
| 203 | + |
186 | 204 | x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
|
187 | 205 | for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
|
188 |
| - if accumulated_x̄ === nothing # then we marked this argument as not differentiable |
| 206 | + if accumulated_x̄ isa Union{Nothing, DoesNotExist} # then we marked this argument as not differentiable # TODO remove once #113 |
189 | 207 | @assert x̄_fd === nothing # this is how `_make_j′vp_call` works
|
190 | 208 | @test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
|
191 | 209 | else
|
|
0 commit comments