@@ -39,6 +39,7 @@ additional constraints:
3939- `rtol`: Relative tolerance for `isapprox`.
4040- `atol`: Absolute tolerance for `isapprox`.
4141- `testset_name`: Name to use for a testset in which all tests are evaluated.
42+ - `output_tangent`: Optional final tangent to provide at the beginning of the reverse-mode differentiation
4243
4344# Examples
4445
@@ -76,8 +77,8 @@ function test_reverse(
7677 rtol:: Real = 1e-9 ,
7778 atol:: Real = 1e-9 ,
7879 testset_name= nothing ,
79- runtime_activity:: Bool = false
80- )
80+ runtime_activity:: Bool = false ,
81+ output_tangent = nothing )
8182 call_with_captured_kwargs (f, xs... ) = f (xs... ; fkwargs... )
8283 if testset_name === nothing
8384 testset_name = " test_reverse: $f with return activity $ret_activity on $(_string_activity (args)) "
@@ -92,12 +93,12 @@ function test_reverse(
9293 y = fcopy (args_copy... ; deepcopy (fkwargs)... )
9394 # generate tangent for output
9495 if ! _any_batch_duplicated (ret_activity, map (typeof, activities)... )
95- ȳ = ret_activity <: Const ? zero_tangent (y) : rand_tangent (rng, y)
96+ ȳ = isnothing (output_tangent) ? ( ret_activity <: Const ? zero_tangent (y) : rand_tangent (rng, y)) : output_tangent
9697 else
9798 batch_size = _batch_size (ret_activity, map (typeof, activities)... )
9899 ks = ntuple (Symbol ∘ string, batch_size)
99100 ȳ = ntuple (batch_size) do _
100- return ret_activity <: Const ? zero_tangent (y) : rand_tangent (y )
101+ return isnothing (output_tangent) ? ( ret_activity <: Const ? zero_tangent (y) : rand_tangent (rng, y)) : copy (output_tangent )
101102 end
102103 end
103104 # call finitedifferences, avoid mutating original arguments
0 commit comments