|
23 | 23 | kwargs = NamedTuple()
|
24 | 24 | end
|
25 | 25 |
|
26 |
| - y = fnorm(x) |
27 |
| - x̄ = rand_tangent(x) |
28 |
| - ȳ = rand_tangent(y) |
29 |
| - |
30 | 26 | fnorm === LinearAlgebra.norm2 && @testset "frule" begin
|
31 |
| - ẋ = rand_tangent(x) |
32 |
| - frule_test(fnorm, (x, ẋ)) |
| 27 | + test_frule(fnorm, x) |
33 | 28 | end
|
34 | 29 | @testset "rrule" begin
|
35 |
| - rrule_test(fnorm, ȳ, (x, x̄); kwargs...) |
| 30 | + test_rrule(fnorm, x; kwargs...) |
36 | 31 | x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
|
37 |
| - rrule_test(fnorm, ȳ, (MT(x), MT(x̄)); kwargs...) |
| 32 | + test_rrule(fnorm, MT(x); kwargs...) |
38 | 33 | end
|
| 34 | + |
| 35 | + ȳ = rand_tangent(fnorm(x)) |
39 | 36 | @test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) ≈ zero(x)
|
40 | 37 | @test rrule(fnorm, x)[2](Zero())[2] isa Zero
|
41 | 38 | end
|
|
45 | 42 | sz in [(0,), (3,), (3, 3), (3, 2, 1)]
|
46 | 43 |
|
47 | 44 | x = randn(T, sz)
|
48 |
| - y = norm(x) |
49 |
| - ẋ = rand_tangent(x) |
50 |
| - x̄ = rand_tangent(x) |
51 |
| - ȳ = rand_tangent(y) |
52 | 45 |
|
53 | 46 | @testset "frule" begin
|
54 |
| - frule_test(norm, (x, ẋ)) |
| 47 | + test_frule(norm, x) |
55 | 48 | @test frule((Zero(), Zero()), norm, x)[2] isa Zero
|
| 49 | + |
| 50 | + ẋ = rand_tangent(x) |
56 | 51 | @test iszero(frule((Zero(), ẋ), norm, zero(x))[2])
|
57 | 52 | end
|
58 | 53 | @testset "rrule" begin
|
59 |
| - rrule_test(norm, ȳ, (x, x̄)) |
| 54 | + test_rrule(norm, x) |
60 | 55 | x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
|
61 | 56 | # we don't check inference on older julia versions. Improvements to
|
62 | 57 | # inference mean on 1.5+ it works, and that is good enough
|
63 |
| - rrule_test(norm, ȳ, (MT(x), MT(x̄)); check_inferred=VERSION>=v"1.5") |
| 58 | + test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5") |
64 | 59 | end
|
| 60 | + |
| 61 | + ȳ = rand_tangent(norm(x)) |
65 | 62 | @test extern(rrule(norm, zero(x))[2](ȳ)[2]) ≈ zero(x)
|
66 | 63 | @test rrule(norm, x)[2](Zero())[2] isa Zero
|
67 | 64 | end
|
|
90 | 87 | kwargs = NamedTuple()
|
91 | 88 | end
|
92 | 89 |
|
93 |
| - y = fnorm(x, p) |
94 |
| - x̄ = rand_tangent(x) |
95 |
| - ȳ = rand_tangent(y) |
96 |
| - p̄ = rand_tangent(p) |
97 | 90 |
|
98 |
| - rrule_test(fnorm, ȳ, (x, x̄), (p, p̄); kwargs...) |
| 91 | + test_rrule(fnorm, x, p; kwargs...) |
99 | 92 | x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
|
100 |
| - rrule_test( |
101 |
| - fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄); |
| 93 | + test_rrule(fnorm, MT(x), p; |
102 | 94 | #Don't check inference on old julia, what matters is that works on new
|
103 | 95 | check_inferred=VERSION>=v"1.5", kwargs...
|
104 | 96 | )
|
105 | 97 | end
|
| 98 | + |
| 99 | + ȳ = rand_tangent(fnorm(x)) |
106 | 100 | @test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) ≈ zero(x)
|
107 | 101 | @test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
|
108 | 102 | end
|
|
116 | 110 | x̄ = rand_tangent(x)
|
117 | 111 | ȳ = rand_tangent(y)
|
118 | 112 | p̄ = rand_tangent(p)
|
119 |
| - rrule_test(norm, ȳ, (x, x̄), (p, p̄)) |
| 113 | + test_rrule(norm, ȳ, (x, x̄), (p, p̄)) |
120 | 114 | @test extern(rrule(norm, x, p)[2](ȳ)[2]) isa typeof(x)
|
121 | 115 | end
|
122 | 116 | @testset "norm(x::$T, p)" for T in (Float64, ComplexF64)
|
|
125 | 119 | y = norm(x, p)
|
126 | 120 | ẋ, ṗ = rand_tangent.((x, p))
|
127 | 121 | x̄, p̄, ȳ = rand_tangent.((x, p, y))
|
128 |
| - frule_test(norm, (x, ẋ), (p, ṗ)) |
129 |
| - rrule_test(norm, ȳ, (x, x̄), (p, p̄)) |
| 122 | + test_frule(norm, (x, ẋ), (p, ṗ)) |
| 123 | + test_rrule(norm, ȳ, (x, x̄), (p, p̄)) |
130 | 124 | _, back = rrule(norm, x, p)
|
131 | 125 | @test back(Zero()) == (NO_FIELDS, Zero(), Zero())
|
132 | 126 | end
|
|
154 | 148 | y = normalize(x)
|
155 | 149 | x̄ = rand_tangent(x)
|
156 | 150 | ȳ = rand_tangent(y)
|
157 |
| - rrule_test(normalize, ȳ, (x, x̄)) |
| 151 | + test_rrule(normalize, ȳ, (x, x̄)) |
158 | 152 | @test rrule(normalize, x)[2](Zero()) === (NO_FIELDS, Zero())
|
159 | 153 | end
|
160 | 154 | @testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64),
|
|
165 | 159 | x̄ = rand_tangent(x)
|
166 | 160 | ȳ = rand_tangent(y)
|
167 | 161 | p̄ = rand_tangent(p)
|
168 |
| - rrule_test(normalize, ȳ, (x, x̄), (p, p̄)) |
| 162 | + test_rrule(normalize, ȳ, (x, x̄), (p, p̄)) |
169 | 163 | @test rrule(normalize, x, p)[2](Zero()) === (NO_FIELDS, Zero(), Zero())
|
170 | 164 | end
|
171 | 165 | end
|
0 commit comments