Skip to content

Commit 653ae98

Browse files
authored
remove rand_tangent from ChainRulesTestUtils and use the one from FiniteDifferences (#147)
1 parent fe996bc commit 653ae98

File tree

7 files changed

+8
-96
lines changed

7 files changed

+8
-96
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.6.9"
3+
version = "0.6.10"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ Test.DefaultTestSet("test_scalar: relu at -0.5", Any[Test.DefaultTestSet("with t
116116
## Specifying Tangents
117117
[`test_frule`](@ref) and [`test_rrule`](@ref) allow you to specify the tangents used for testing.
118118
This is done by passing in `x ⊢ Δx`, where `x` is the primal and `Δx` is the tangent, in the place of the primal inputs.
119-
If this is not done the tangent will be automatically generated via [`ChainRulesTestUtils.rand_tangent`](@ref).
119+
If this is not done the tangent will be automatically generated via `FiniteDifferences.rand_tangent`.
120120
A special case of this is that if you specify it as `x ⊢ DoesNotExist()` then finite differencing will not be used on that input.
121121
Similarly, by setting the `output_tangent` keyword argument, you can specify the tangent for the primal output.
122122

123-
This can be useful when the default provided [`ChainRulesTestUtils.rand_tangent`](@ref) doesn't produce the desired tangent for your type.
123+
This can be useful when the default provided `FiniteDifferences.rand_tangent` doesn't produce the desired tangent for your type.
124124
For example the default tangent for an `Int` is `DoesNotExist()`.
125125
Which is correct e.g. when the `Int` represents a discrete integer like in indexing.
126126
But if you are testing something where the `Int` is actually a special case of a real number, then you would want to specify the tangent as a `Float64`.
@@ -161,7 +161,3 @@ which should have passed the test.
161161
Modules = [ChainRulesTestUtils]
162162
Private = false
163163
```
164-
165-
```@docs
166-
ChainRulesTestUtils.rand_tangent
167-
```

src/ChainRulesTestUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ using LinearAlgebra
99
using Random
1010
using Test
1111

12+
import FiniteDifferences: rand_tangent
13+
1214
const _fdm = central_fdm(5, 1; max_range=1e-2)
1315

1416
export TestIterator

src/generate_tangent.jl

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
44
Use this in the place of a tangent/cotangent in [`test_frule`](@ref) or
55
[`test_rrule`](@ref) to have that tangent/cotangent generated automatically based on the
6-
primal. Uses [`rand_tangent`](@ref)
6+
primal. Uses `FiniteDifferences.rand_tangent`.
77
"""
88
struct Auto end
99

@@ -38,47 +38,3 @@ This function is idempotent. If you pass it a `PrimalAndTangent` it doesn't chan
3838
"""
3939
auto_primal_and_tangent(primal; rng=Random.GLOBAL_RNG) = primal rand_tangent(rng, primal)
4040
auto_primal_and_tangent(both::PrimalAndTangent; kwargs...) = both
41-
42-
"""
43-
rand_tangent([rng::AbstractRNG,] x)
44-
45-
Returns a randomly generated tangent vector appropriate for the primal value `x`.
46-
"""
47-
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)
48-
49-
function rand_tangent(rng::AbstractRNG, x::Union{Symbol, AbstractChar, AbstractString})
50-
return DoesNotExist()
51-
end
52-
53-
rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()
54-
55-
rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)
56-
57-
# ref: https://github.com/JuliaLang/julia/issues/17629
58-
rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))
59-
60-
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
61-
62-
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
63-
return Composite{T}(rand_tangent.(Ref(rng), x)...)
64-
end
65-
66-
function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
67-
return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...)
68-
end
69-
70-
function rand_tangent(rng::AbstractRNG, x::T) where {T}
71-
if !isstructtype(T)
72-
throw(ArgumentError("Non-struct types are not supported by this fallback."))
73-
end
74-
75-
field_names = fieldnames(T)
76-
if length(field_names) > 0
77-
tangents = map(field_names) do field_name
78-
rand_tangent(rng, getfield(x, field_name))
79-
end
80-
return Composite{T}(; NamedTuple{field_names}(tangents)...)
81-
else
82-
return NO_FIELDS
83-
end
84-
end

test/generate_tangent.jl

Lines changed: 0 additions & 42 deletions
This file was deleted.

test/iterator.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
@testset "rand_tangent" begin
9393
data = randn(2, 3, 4)
9494
iter = TestIterator(data, Base.SizeUnknown(), Base.EltypeUnknown())
95-
∂iter = rand_tangent(iter)
95+
∂iter = FiniteDifferences.rand_tangent(iter)
9696
@test ∂iter isa typeof(iter)
9797
@test size(∂iter.data) == size(iter.data)
9898
@test eltype(∂iter.data) === eltype(iter.data)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
using ChainRulesCore
22
using ChainRulesTestUtils
3+
using FiniteDifferences
34
using LinearAlgebra
45
using Random
56
using Test
67

78
@testset "ChainRulesTestUtils.jl" begin
89
include("meta_testing_tools.jl")
9-
include("generate_tangent.jl")
1010
include("iterator.jl")
1111
include("check_result.jl")
1212
include("testers.jl")

0 commit comments

Comments
 (0)