23
23
kwargs = NamedTuple ()
24
24
end
25
25
26
- y = fnorm (x)
27
- x̄ = rand_tangent (x)
28
- ȳ = rand_tangent (y)
29
-
30
26
fnorm === LinearAlgebra. norm2 && @testset " frule" begin
31
- ẋ = rand_tangent (x)
32
- frule_test (fnorm, (x, ẋ))
27
+ test_frule (fnorm, x)
33
28
end
34
29
@testset " rrule" begin
35
- rrule_test (fnorm, ȳ, (x, x̄) ; kwargs... )
30
+ test_rrule (fnorm, x ; kwargs... )
36
31
x isa Matrix && @testset " $MT " for MT in (Diagonal, UpperTriangular, LowerTriangular)
37
- rrule_test (fnorm, ȳ, ( MT (x), MT (x̄) ); kwargs... )
32
+ test_rrule (fnorm, MT (x); kwargs... )
38
33
end
34
+
35
+ ȳ = rand_tangent (fnorm (x))
39
36
@test extern (rrule (fnorm, zero (x))[2 ](ȳ)[2 ]) ≈ zero (x)
40
37
@test rrule (fnorm, x)[2 ](Zero ())[2 ] isa Zero
41
38
end
45
42
sz in [(0 ,), (3 ,), (3 , 3 ), (3 , 2 , 1 )]
46
43
47
44
x = randn (T, sz)
48
- y = norm (x)
49
- ẋ = rand_tangent (x)
50
- x̄ = rand_tangent (x)
51
- ȳ = rand_tangent (y)
52
45
53
46
@testset " frule" begin
54
- frule_test (norm, (x, ẋ) )
47
+ test_frule (norm, x )
55
48
@test frule ((Zero (), Zero ()), norm, x)[2 ] isa Zero
49
+
50
+ ẋ = rand_tangent (x)
56
51
@test iszero (frule ((Zero (), ẋ), norm, zero (x))[2 ])
57
52
end
58
53
@testset " rrule" begin
59
- rrule_test (norm, ȳ, (x, x̄) )
54
+ test_rrule (norm, x )
60
55
x isa Matrix && @testset " $MT " for MT in (Diagonal, UpperTriangular, LowerTriangular)
61
56
# we don't check inference on older julia versions. Improvements to
62
57
# inference mean on 1.5+ it works, and that is good enough
63
- rrule_test (norm, ȳ, ( MT (x), MT (x̄) ); check_inferred= VERSION >= v " 1.5" )
58
+ test_rrule (norm, MT (x); check_inferred= VERSION >= v " 1.5" )
64
59
end
60
+
61
+ ȳ = rand_tangent (norm (x))
65
62
@test extern (rrule (norm, zero (x))[2 ](ȳ)[2 ]) ≈ zero (x)
66
63
@test rrule (norm, x)[2 ](Zero ())[2 ] isa Zero
67
64
end
90
87
kwargs = NamedTuple ()
91
88
end
92
89
93
- y = fnorm (x, p)
94
- x̄ = rand_tangent (x)
95
- ȳ = rand_tangent (y)
96
- p̄ = rand_tangent (p)
97
90
98
- rrule_test (fnorm, ȳ, ( x, x̄), (p, p̄) ; kwargs... )
91
+ test_rrule (fnorm, x, p ; kwargs... )
99
92
x isa Matrix && @testset " $MT " for MT in (Diagonal, UpperTriangular, LowerTriangular)
100
- rrule_test (
101
- fnorm, ȳ, (MT (x), MT (x̄)), (p, p̄);
93
+ test_rrule (fnorm, MT (x), p;
102
94
# Don't check inference on old julia, what matters is that works on new
103
95
check_inferred= VERSION >= v " 1.5" , kwargs...
104
96
)
105
97
end
98
+
99
+ ȳ = rand_tangent (fnorm (x, p))
106
100
@test extern (rrule (fnorm, zero (x), p)[2 ](ȳ)[2 ]) ≈ zero (x)
107
101
@test rrule (fnorm, x, p)[2 ](Zero ())[2 ] isa Zero
108
102
end
109
103
@testset " norm($fdual (::Vector{$T }), p)" for
110
104
T in (Float64, ComplexF64),
111
105
fdual in (adjoint, transpose)
106
+
107
+ x = fdual (randn (T, 3 ))
112
108
p = 2.5
113
- n = 3
114
- x = fdual (randn (T, n))
115
- y = norm (x, p)
116
- x̄ = rand_tangent (x)
117
- ȳ = rand_tangent (y)
118
- p̄ = rand_tangent (p)
119
- rrule_test (norm, ȳ, (x, x̄), (p, p̄))
109
+
110
+ test_rrule (norm, x, p)
111
+ ȳ = rand_tangent (norm (x, p))
120
112
@test extern (rrule (norm, x, p)[2 ](ȳ)[2 ]) isa typeof (x)
121
113
end
122
114
@testset " norm(x::$T , p)" for T in (Float64, ComplexF64)
123
115
@testset " p = $p " for p in (- 1.0 , 2.0 , 2.5 )
124
- x = randn (T)
125
- y = norm (x, p)
126
- ẋ, ṗ = rand_tangent .((x, p))
127
- x̄, p̄, ȳ = rand_tangent .((x, p, y))
128
- frule_test (norm, (x, ẋ), (p, ṗ))
129
- rrule_test (norm, ȳ, (x, x̄), (p, p̄))
130
- _, back = rrule (norm, x, p)
116
+ test_frule (norm, randn (T), p)
117
+ test_rrule (norm, randn (T), p)
118
+
119
+ _, back = rrule (norm, randn (T), p)
131
120
@test back (Zero ()) == (NO_FIELDS, Zero (), Zero ())
132
121
end
133
122
@testset " p = 0" begin
@@ -149,23 +138,14 @@ end
149
138
150
139
@testset " normalize" begin
151
140
@testset " x::Vector{$T }" for T in (Float64, ComplexF64)
152
- n = 3
153
- x = randn (T, n)
154
- y = normalize (x)
155
- x̄ = rand_tangent (x)
156
- ȳ = rand_tangent (y)
157
- rrule_test (normalize, ȳ, (x, x̄))
141
+ x = randn (T, 3 )
142
+ test_rrule (normalize, x)
158
143
@test rrule (normalize, x)[2 ](Zero ()) === (NO_FIELDS, Zero ())
159
144
end
160
145
@testset " x::Vector{$T }, p=$p " for T in (Float64, ComplexF64),
161
146
p in (1.0 , 2.0 , - Inf , Inf , 2.5 ) # skip p=0, since FD is unstable
162
- n = 3
163
- x = randn (T, n)
164
- y = normalize (x, p)
165
- x̄ = rand_tangent (x)
166
- ȳ = rand_tangent (y)
167
- p̄ = rand_tangent (p)
168
- rrule_test (normalize, ȳ, (x, x̄), (p, p̄))
147
+ x = randn (T, 3 )
148
+ test_rrule (normalize, x, p)
169
149
@test rrule (normalize, x, p)[2 ](Zero ()) === (NO_FIELDS, Zero (), Zero ())
170
150
end
171
151
end
0 commit comments