8080# Keyword Arguments
8181 - `output_tangent` tangent to test accumulation of derivatives against
8282 should be a differential for the output of `f`. Is set automatically if not provided.
83- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
84- transform the passed argument tangents into alternative tangents that should be tested.
85- Note that the alternative tangents are only tested for not erroring when passed to
86- frule. Testing for correctness using finite differencing can be done using a
87- separate `test_frule` call, e.g. for testing a `ZeroTangent()` for correctness:
88- `test_frule(f, x ⊢ ZeroTangent(); tangent_transforms=[])`.
8983 - `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
9084 - `frule_f=frule`: Function with an `frule`-like API that is tested (defaults to
9185 `frule`). Used for testing gradients from AD systems.
@@ -104,7 +98,6 @@ function test_frule(
10498 f,
10599 args... ;
106100 output_tangent= Auto (),
107- tangent_transforms= TRANSFORMS_TO_ALT_TANGENTS,
108101 fdm= _fdm,
109102 frule_f= ChainRulesCore. frule,
110103 check_inferred:: Bool = true ,
@@ -143,25 +136,9 @@ function test_frule(
143136
144137 acc = output_tangent isa Auto ? rand_tangent (Ω) : output_tangent
145138 _test_add!!_behaviour (acc, dΩ_ad; isapprox_kwargs... )
146-
147- # test that rules work for other tangents
148- _test_frule_alt_tangents (
149- call_on_copy, frule_f, config, tangent_transforms, tangents, primals, acc;
150- isapprox_kwargs...
151- )
152139 end # top-level testset
153140end
154141
155- function _test_frule_alt_tangents (
156- call, frule_f, config, tangent_transforms, tangents, primals, acc;
157- isapprox_kwargs...
158- )
159- @testset " ȧrgs = $(_string_typeof (tsf .(tangents))) " for tsf in tangent_transforms
160- _, dΩ = call (frule_f, config, tsf .(tangents), primals... )
161- _test_add!!_behaviour (acc, dΩ; isapprox_kwargs... )
162- end
163- end
164-
165142"""
166143 test_rrule([config::RuleConfig,] f, args...; kwargs...)
167144
176153# Keyword Arguments
177154 - `output_tangent` the seed to propagate backward for testing (technically a cotangent).
178155 should be a differential for the output of `f`. Is set automatically if not provided.
179- - `tangent_transforms=TRANSFORMS_TO_ALT_TANGENTS`: a vector of functions that
180- transform the passed `output_tangent` into alternative tangents that should be tested.
181- Note that the alternative tangents are only tested for not erroring when passed to
182- rrule. Testing for correctness using finite differencing can be done using a
183- separate `test_rrule` call, e.g. for testing a `ZeroTangent()` for correctness:
184- `test_rrule(f, args...; output_tangent=ZeroTangent(), tangent_transforms=[])`.
156+ - `check_thunked_output_tangent=true`: also checks that passing a thunked version of the
157+ output tangent to the pullback returns the same result.
185158 - `fdm::FiniteDifferenceMethod`: the finite differencing method to use.
186159 - `rrule_f=rrule`: Function with an `rrule`-like API that is tested (defaults to `rrule`).
187160 Used for testing gradients from AD systems.
@@ -200,7 +173,7 @@ function test_rrule(
200173 f,
201174 args... ;
202175 output_tangent= Auto (),
203- tangent_transforms = TRANSFORMS_TO_ALT_TANGENTS ,
176+ check_thunked_output_tangent = true ,
204177 fdm= _fdm,
205178 rrule_f= ChainRulesCore. rrule,
206179 check_inferred:: Bool = true ,
@@ -267,21 +240,10 @@ function test_rrule(
267240 end
268241 end
269242
270- # test other tangents don't error when passed to the pullback
271- _test_rrule_alt_tangents (pullback, tangent_transforms, ȳ, accum_cotangents)
272- end # top-level testset
273- end
274-
275- function _test_rrule_alt_tangents (
276- pullback, tangent_transforms, ȳ, accum_cotangents;
277- isapprox_kwargs...
278- )
279- @testset " ȳ = $(_string_typeof (tsf (ȳ))) " for tsf in tangent_transforms
280- ad_cotangents = pullback (tsf (ȳ))
281- for (accum_cotangent, ad_cotangent) in zip (accum_cotangents, ad_cotangents)
282- _test_add!!_behaviour (accum_cotangent, ad_cotangent; isapprox_kwargs... )
243+ if check_thunked_output_tangent
244+ test_approx (ad_cotangents, pullback (@thunk (ȳ)), " pulling back a thunk" )
283245 end
284- end
246+ end # top-level testset
285247end
286248
287249"""
0 commit comments