8080# Keyword Arguments
8181 - `output_tangent` tangent to test accumulation of derivatives against
8282 should be a differential for the output of `f`. Is set automatically if not provided.
83- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84- transform the passed argument tangents into alternative tangents that should be tested.
85- Note that the alternative tangents are only tested for not erroring when passed to
86- frule. Testing for correctness using finite differencing can be done using a
87- separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88- `test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
8983 - `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
9084 - `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
9185 `frule`). Used for testing gradients from AD systems.
@@ -104,7 +98,6 @@ function test_frule(
10498 f,
10599 args... ;
106100 output_tangent= Auto (),
107- tangent_transforms= TRANSFORMS_TO_ALT_TANGENTS,
108101 fdm= _fdm,
109102 frule_f= ChainRulesCore. frule,
110103 check_inferred:: Bool = true ,
@@ -136,41 +129,16 @@ function test_frule(
136129 Ω = call_on_copy (primals... )
137130 test_approx (Ω_ad, Ω; isapprox_kwargs... )
138131
139- # TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
140- is_ignored = isa .(tangents, Union{Nothing,NoTangent})
141- if any (tangents .== nothing )
142- Base. depwarn (
143- " test_frule(f, k ⊢ nothing) is deprecated, use " *
144- " test_frule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
145- :test_frule ,
146- )
147- end
148-
149132 # Correctness testing via finite differencing.
133+ is_ignored = isa .(tangents, NoTangent)
150134 dΩ_fd = _make_jvp_call (fdm, call_on_copy, Ω, primals, tangents, is_ignored)
151135 test_approx (dΩ_ad, dΩ_fd; isapprox_kwargs... )
152136
153137 acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
154138 _test_add!!_behaviour (acc, dΩ_ad; isapprox_kwargs... )
155-
156- # test that rules work for other tangents
157- _test_frule_alt_tangents (
158- call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
159- isapprox_kwargs...
160- )
161139 end # top-level testset
162140end
163141
164- function _test_frule_alt_tangents (
165- call, frule_f, config, tangent_transforms, tangents, primals, acc;
166- isapprox_kwargs...
167- )
168- @testset " ȧrgs = $(_string_typeof (tsf .(tangents))) " for tsf in tangent_transforms
169- _, dΩ = call (frule_f, config, tsf .(tangents), primals... )
170- _test_add!!_behaviour (acc, dΩ; isapprox_kwargs... )
171- end
172- end
173-
174142"""
175143 test_rrule([config::RuleConfig,] f, args...; kwargs...)
176144
185153# Keyword Arguments
186154 - `output_tangent` the seed to propagate backward for testing (technically a cotangent).
187155 should be a differential for the output of `f`. Is set automatically if not provided.
188- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
189- transform the passed `output_tangent` into alternative tangents that should be tested.
190- Note that the alternative tangents are only tested for not erroring when passed to
191- rrule. Testing for correctness using finite differencing can be done using a
192- separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
193- `test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
156+ - `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157+ output tangent to the pullback returns the same result.
194158 - `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
195159 - `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
196160 Used for testing gradients from AD systems.
@@ -209,7 +173,7 @@ function test_rrule(
209173 f,
210174 args... ;
211175 output_tangent= Auto (),
212- tangent_transforms = TRANSFORMS_TO_ALT_TANGENTS ,
176+ check_thunked_output_tangent = true ,
213177 fdm= _fdm,
214178 rrule_f= ChainRulesCore. rrule,
215179 check_inferred:: Bool = true ,
@@ -254,22 +218,13 @@ function test_rrule(
254218 )
255219
256220 # Correctness testing via finite differencing.
257- # TODO : remove Nothing when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/113
258- is_ignored = isa .(accum_cotangents, Union{Nothing, NoTangent})
259- if any (accum_cotangents .== nothing )
260- Base. depwarn (
261- " test_rrule(f, k ⊢ nothing) is deprecated, use " *
262- " test_rrule(f, k ⊢ NoTangent()) instead for non-differentiable ks" ,
263- :test_rrule ,
264- )
265- end
266-
221+ is_ignored = isa .(accum_cotangents, NoTangent)
267222 fd_cotangents = _make_j′vp_call (fdm, call, ȳ, primals, is_ignored)
268223
269224 for (accum_cotangent, ad_cotangent, fd_cotangent) in zip (
270225 accum_cotangents, ad_cotangents, fd_cotangents
271226 )
272- if accum_cotangent isa Union{Nothing, NoTangent} # then we marked this argument as not differentiable # TODO remove once #113
227+ if accum_cotangent isa NoTangent # then we marked this argument as not differentiable
273228 @assert fd_cotangent === nothing # this is how `_make_j′vp_call` works
274229 ad_cotangent isa ZeroTangent && error (
275230 " The pullback in the rrule should use NoTangent()" *
@@ -285,21 +240,10 @@ function test_rrule(
285240 end
286241 end
287242
288- # test other tangents don't error when passed to the pullback
289- _test_rrule_alt_tangents (pullback, tangent_transforms, ȳ, accum_cotangents)
290- end # top-level testset
291- end
292-
293- function _test_rrule_alt_tangents (
294- pullback, tangent_transforms, ȳ, accum_cotangents;
295- isapprox_kwargs...
296- )
297- @testset " ȳ = $(_string_typeof (tsf (ȳ))) " for tsf in tangent_transforms
298- ad_cotangents = pullback (tsf (ȳ))
299- for (accum_cotangent, ad_cotangent) in zip (accum_cotangents, ad_cotangents)
300- _test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
243+ if check_thunked_output_tangent
244+ test_approx (ad_cotangents, pullback (@thunk (ȳ)), " pulling back a thunk:" )
301245 end
302- end
246+ end # top-level testset
303247end
304248
305249"""
0 commit comments