Skip to content

Commit 23c78a6

Browse files
willtebbuttoxinaboxsethaxen
authored
Automatic tangent generation for simple types (#43)
* Implements for simple types * Adds rng-less default. Co-authored-by: Lyndon White <[email protected]> * Process Lyndon's reviews * Adds docstring * Bump compat * Fix rng broadcasting Co-authored-by: Seth Axen <[email protected]> * Fix rng broadcasting Co-authored-by: Seth Axen <[email protected]> * Adds Bool test * Fixes rng broadcasting * Bumps patch version * Adds struct rand_tangent * Tests nested struct * Checks for isstructtype Co-authored-by: Lyndon White <[email protected]> * Fixes isstructtype * Uses default_rng() Co-authored-by: Lyndon White <[email protected]> * Reverts to GLOBAL_RNG Co-authored-by: Lyndon White <[email protected]> Co-authored-by: Seth Axen <[email protected]>
1 parent 250dfac commit 23c78a6

File tree

5 files changed

+85
-2
lines changed

5 files changed

+85
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesTestUtils"
22
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3-
version = "0.4.2"
3+
version = "0.4.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1212

1313
[compat]
14-
ChainRulesCore = "0.9"
14+
ChainRulesCore = "0.9.1"
1515
Compat = "3"
1616
FiniteDifferences = "0.10"
1717
julia = "1"

src/ChainRulesTestUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ const _fdm = central_fdm(5, 1)
1313

1414
export test_scalar, frule_test, rrule_test, generate_well_conditioned_matrix
1515

16+
include("generate_tangent.jl")
1617
include("to_vec.jl")
1718
include("isapprox.jl")
1819
include("data_generation.jl")

src/generate_tangent.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""
2+
rand_tangent([rng::AbstractRNG,] x)
3+
4+
Returns a randomly generated tangent vector appropriate for the primal value `x`.
5+
"""
6+
rand_tangent(x) = rand_tangent(Random.GLOBAL_RNG, x)
7+
8+
function rand_tangent(rng::AbstractRNG, x::Union{Symbol, AbstractChar, AbstractString})
9+
return DoesNotExist()
10+
end
11+
12+
rand_tangent(rng::AbstractRNG, x::Integer) = DoesNotExist()
13+
14+
rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)
15+
16+
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
17+
18+
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
19+
return Composite{T}(rand_tangent.(Ref(rng), x)...)
20+
end
21+
22+
function rand_tangent(rng::AbstractRNG, xs::T) where {T<:NamedTuple}
23+
return Composite{T}(; map(x -> rand_tangent(rng, x), xs)...)
24+
end
25+
26+
function rand_tangent(rng::AbstractRNG, x::T) where {T}
27+
if !isstructtype(T)
28+
throw(ArgumentError("Non-struct types are not supported by this fallback."))
29+
end
30+
31+
field_names = fieldnames(T)
32+
if length(field_names) > 0
33+
tangents = map(field_names) do field_name
34+
rand_tangent(rng, getfield(x, field_name))
35+
end
36+
return Composite{T}(; NamedTuple{field_names}(tangents)...)
37+
else
38+
return NO_FIELDS
39+
end
40+
end

test/generate_tangent.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
using ChainRulesTestUtils: rand_tangent
2+
3+
# Test struct for `rand_tangent`.
4+
struct Foo
5+
a::Float64
6+
b::Int
7+
c::Any
8+
end
9+
10+
@testset "generate_tangent" begin
11+
rng = MersenneTwister(123456)
12+
13+
foreach([
14+
("hi", DoesNotExist),
15+
('a', DoesNotExist),
16+
(:a, DoesNotExist),
17+
(true, DoesNotExist),
18+
(4, DoesNotExist),
19+
(5.0, Float64),
20+
(5.0 + 0.4im, Complex{Float64}),
21+
(randn(Float32, 3), Vector{Float32}),
22+
(randn(Complex{Float64}, 2), Vector{Complex{Float64}}),
23+
(randn(5, 4), Matrix{Float64}),
24+
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
25+
([randn(5, 4), 4.0], Vector{Any}),
26+
((4.0, ), Composite{Tuple{Float64}}),
27+
((5.0, randn(3)), Composite{Tuple{Float64, Vector{Float64}}}),
28+
((a=4.0, ), Composite{NamedTuple{(:a,), Tuple{Float64}}}),
29+
((a=5.0, b=1), Composite{NamedTuple{(:a, :b), Tuple{Float64, Int}}}),
30+
(sin, typeof(NO_FIELDS)),
31+
(Foo(5.0, 4, rand(rng, 3)), Composite{Foo}),
32+
(Foo(4.0, 3, Foo(5.0, 2, 4)), Composite{Foo}),
33+
]) do (x, T_tangent)
34+
@test rand_tangent(rng, x) isa T_tangent
35+
@test rand_tangent(x) isa T_tangent
36+
@test x + rand_tangent(rng, x) isa typeof(x)
37+
end
38+
39+
# Ensure struct fallback errors for non-struct types.
40+
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
41+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Random
55
using Test
66

77
@testset "ChainRulesTestUtils.jl" begin
8+
include("generate_tangent.jl")
89
include("to_vec.jl")
910
include("isapprox.jl")
1011
include("testers.jl")

0 commit comments

Comments
 (0)