Skip to content

Commit a492fa4

Browse files
authored
Add optional output_tangent kwarg to test_reverse (#2588)
1 parent 2519c73 commit a492fa4

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

lib/EnzymeTestUtils/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "EnzymeTestUtils"
22
uuid = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
33
authors = ["Seth Axen <[email protected]>", "William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
4-
version = "0.2.2"
4+
version = "0.2.3"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"

lib/EnzymeTestUtils/src/test_reverse.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ additional constraints:
3939
- `rtol`: Relative tolerance for `isapprox`.
4040
- `atol`: Absolute tolerance for `isapprox`.
4141
- `testset_name`: Name to use for a testset in which all tests are evaluated.
42+
- `output_tangent`: Optional final tangent to provide at the beginning of the reverse-mode differentiation
4243
4344
# Examples
4445
@@ -76,8 +77,8 @@ function test_reverse(
7677
rtol::Real=1e-9,
7778
atol::Real=1e-9,
7879
testset_name=nothing,
79-
runtime_activity::Bool=false
80-
)
80+
runtime_activity::Bool=false,
81+
output_tangent=nothing)
8182
call_with_captured_kwargs(f, xs...) = f(xs...; fkwargs...)
8283
if testset_name === nothing
8384
testset_name = "test_reverse: $f with return activity $ret_activity on $(_string_activity(args))"
@@ -92,12 +93,12 @@ function test_reverse(
9293
y = fcopy(args_copy...; deepcopy(fkwargs)...)
9394
# generate tangent for output
9495
if !_any_batch_duplicated(ret_activity, map(typeof, activities)...)
95-
= ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y)
96+
= isnothing(output_tangent) ? (ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y)) : output_tangent
9697
else
9798
batch_size = _batch_size(ret_activity, map(typeof, activities)...)
9899
ks = ntuple(Symbol string, batch_size)
99100
= ntuple(batch_size) do _
100-
return ret_activity <: Const ? zero_tangent(y) : rand_tangent(y)
101+
return isnothing(output_tangent) ? (ret_activity <: Const ? zero_tangent(y) : rand_tangent(rng, y)) : copy(output_tangent)
101102
end
102103
end
103104
# call finitedifferences, avoid mutating original arguments

0 commit comments

Comments
 (0)