96
96
)
97
97
end
98
98
99
- ȳ = rand_tangent (fnorm (x))
99
+ ȳ = rand_tangent (fnorm (x, p ))
100
100
@test extern (rrule (fnorm, zero (x), p)[2 ](ȳ)[2 ]) ≈ zero (x)
101
101
@test rrule (fnorm, x, p)[2 ](Zero ())[2 ] isa Zero
102
102
end
103
103
@testset " norm($fdual (::Vector{$T }), p)" for
104
104
T in (Float64, ComplexF64),
105
105
fdual in (adjoint, transpose)
106
+
107
+ x = fdual (randn (T, 3 ))
106
108
p = 2.5
107
- n = 3
108
- x = fdual (randn (T, n))
109
- y = norm (x, p)
110
- x̄ = rand_tangent (x)
111
- ȳ = rand_tangent (y)
112
- p̄ = rand_tangent (p)
113
- test_rrule (norm, ȳ, (x, x̄), (p, p̄))
109
+
110
+ test_rrule (norm, x, p)
111
+ ȳ = rand_tangent (norm (x, p))
114
112
@test extern (rrule (norm, x, p)[2 ](ȳ)[2 ]) isa typeof (x)
115
113
end
116
114
@testset " norm(x::$T , p)" for T in (Float64, ComplexF64)
117
115
@testset " p = $p " for p in (- 1.0 , 2.0 , 2.5 )
118
- x = randn (T)
119
- y = norm (x, p)
120
- ẋ, ṗ = rand_tangent .((x, p))
121
- x̄, p̄, ȳ = rand_tangent .((x, p, y))
122
- test_frule (norm, (x, ẋ), (p, ṗ))
123
- test_rrule (norm, ȳ, (x, x̄), (p, p̄))
124
- _, back = rrule (norm, x, p)
116
+ test_frule (norm, randn (T), p)
117
+ test_rrule (norm, randn (T), p)
118
+
119
+ _, back = rrule (norm, randn (T), p)
125
120
@test back (Zero ()) == (NO_FIELDS, Zero (), Zero ())
126
121
end
127
122
@testset " p = 0" begin
@@ -143,23 +138,14 @@ end
143
138
144
139
@testset " normalize" begin
145
140
@testset " x::Vector{$T }" for T in (Float64, ComplexF64)
146
- n = 3
147
- x = randn (T, n)
148
- y = normalize (x)
149
- x̄ = rand_tangent (x)
150
- ȳ = rand_tangent (y)
151
- test_rrule (normalize, ȳ, (x, x̄))
141
+ x = randn (T, 3 )
142
+ test_rrule (normalize, x)
152
143
@test rrule (normalize, x)[2 ](Zero ()) === (NO_FIELDS, Zero ())
153
144
end
154
145
@testset " x::Vector{$T }, p=$p " for T in (Float64, ComplexF64),
155
146
p in (1.0 , 2.0 , - Inf , Inf , 2.5 ) # skip p=0, since FD is unstable
156
- n = 3
157
- x = randn (T, n)
158
- y = normalize (x, p)
159
- x̄ = rand_tangent (x)
160
- ȳ = rand_tangent (y)
161
- p̄ = rand_tangent (p)
162
- test_rrule (normalize, ȳ, (x, x̄), (p, p̄))
147
+ x = randn (T, 3 )
148
+ test_rrule (normalize, x, p)
163
149
@test rrule (normalize, x, p)[2 ](Zero ()) === (NO_FIELDS, Zero (), Zero ())
164
150
end
165
151
end
0 commit comments