Skip to content

Commit d6e9806

Browse files
authored
Merge pull request #377 from JuliaDiff/ox/tannomr
norm.jl autotangent
2 parents 9c75379 + cc03c81 commit d6e9806

File tree

1 file changed

+29
-49
lines changed

1 file changed

+29
-49
lines changed

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,16 @@
2323
kwargs = NamedTuple()
2424
end
2525

26-
y = fnorm(x)
27-
= rand_tangent(x)
28-
= rand_tangent(y)
29-
3026
fnorm === LinearAlgebra.norm2 && @testset "frule" begin
31-
= rand_tangent(x)
32-
frule_test(fnorm, (x, ẋ))
27+
test_frule(fnorm, x)
3328
end
3429
@testset "rrule" begin
35-
rrule_test(fnorm, ȳ, (x, x̄); kwargs...)
30+
test_rrule(fnorm, x; kwargs...)
3631
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
37-
rrule_test(fnorm, ȳ, (MT(x), MT(x̄)); kwargs...)
32+
test_rrule(fnorm, MT(x); kwargs...)
3833
end
34+
35+
= rand_tangent(fnorm(x))
3936
@test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) zero(x)
4037
@test rrule(fnorm, x)[2](Zero())[2] isa Zero
4138
end
@@ -45,23 +42,23 @@
4542
sz in [(0,), (3,), (3, 3), (3, 2, 1)]
4643

4744
x = randn(T, sz)
48-
y = norm(x)
49-
= rand_tangent(x)
50-
= rand_tangent(x)
51-
= rand_tangent(y)
5245

5346
@testset "frule" begin
54-
frule_test(norm, (x, ẋ))
47+
test_frule(norm, x)
5548
@test frule((Zero(), Zero()), norm, x)[2] isa Zero
49+
50+
= rand_tangent(x)
5651
@test iszero(frule((Zero(), ẋ), norm, zero(x))[2])
5752
end
5853
@testset "rrule" begin
59-
rrule_test(norm, ȳ, (x, x̄))
54+
test_rrule(norm, x)
6055
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
6156
# we don't check inference on older julia versions. Improvements to
6257
# inference mean on 1.5+ it works, and that is good enough
63-
rrule_test(norm, ȳ, (MT(x), MT(x̄)); check_inferred=VERSION>=v"1.5")
58+
test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5")
6459
end
60+
61+
= rand_tangent(norm(x))
6562
@test extern(rrule(norm, zero(x))[2](ȳ)[2]) zero(x)
6663
@test rrule(norm, x)[2](Zero())[2] isa Zero
6764
end
@@ -90,44 +87,36 @@
9087
kwargs = NamedTuple()
9188
end
9289

93-
y = fnorm(x, p)
94-
= rand_tangent(x)
95-
= rand_tangent(y)
96-
= rand_tangent(p)
9790

98-
rrule_test(fnorm, ȳ, (x, x̄), (p, p̄); kwargs...)
91+
test_rrule(fnorm, x, p; kwargs...)
9992
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
100-
rrule_test(
101-
fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄);
93+
test_rrule(fnorm, MT(x), p;
10294
#Don't check inference on old julia, what matters is that works on new
10395
check_inferred=VERSION>=v"1.5", kwargs...
10496
)
10597
end
98+
99+
= rand_tangent(fnorm(x, p))
106100
@test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) zero(x)
107101
@test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
108102
end
109103
@testset "norm($fdual(::Vector{$T}), p)" for
110104
T in (Float64, ComplexF64),
111105
fdual in (adjoint, transpose)
106+
107+
x = fdual(randn(T, 3))
112108
p = 2.5
113-
n = 3
114-
x = fdual(randn(T, n))
115-
y = norm(x, p)
116-
= rand_tangent(x)
117-
= rand_tangent(y)
118-
= rand_tangent(p)
119-
rrule_test(norm, ȳ, (x, x̄), (p, p̄))
109+
110+
test_rrule(norm, x, p)
111+
= rand_tangent(norm(x, p))
120112
@test extern(rrule(norm, x, p)[2](ȳ)[2]) isa typeof(x)
121113
end
122114
@testset "norm(x::$T, p)" for T in (Float64, ComplexF64)
123115
@testset "p = $p" for p in (-1.0, 2.0, 2.5)
124-
x = randn(T)
125-
y = norm(x, p)
126-
ẋ, ṗ = rand_tangent.((x, p))
127-
x̄, p̄, ȳ = rand_tangent.((x, p, y))
128-
frule_test(norm, (x, ẋ), (p, ṗ))
129-
rrule_test(norm, ȳ, (x, x̄), (p, p̄))
130-
_, back = rrule(norm, x, p)
116+
test_frule(norm, randn(T), p)
117+
test_rrule(norm, randn(T), p)
118+
119+
_, back = rrule(norm, randn(T), p)
131120
@test back(Zero()) == (NO_FIELDS, Zero(), Zero())
132121
end
133122
@testset "p = 0" begin
@@ -149,23 +138,14 @@ end
149138

150139
@testset "normalize" begin
151140
@testset "x::Vector{$T}" for T in (Float64, ComplexF64)
152-
n = 3
153-
x = randn(T, n)
154-
y = normalize(x)
155-
= rand_tangent(x)
156-
= rand_tangent(y)
157-
rrule_test(normalize, ȳ, (x, x̄))
141+
x = randn(T, 3)
142+
test_rrule(normalize, x)
158143
@test rrule(normalize, x)[2](Zero()) === (NO_FIELDS, Zero())
159144
end
160145
@testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64),
161146
p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable
162-
n = 3
163-
x = randn(T, n)
164-
y = normalize(x, p)
165-
= rand_tangent(x)
166-
= rand_tangent(y)
167-
= rand_tangent(p)
168-
rrule_test(normalize, ȳ, (x, x̄), (p, p̄))
147+
x = randn(T, 3)
148+
test_rrule(normalize, x, p)
169149
@test rrule(normalize, x, p)[2](Zero()) === (NO_FIELDS, Zero(), Zero())
170150
end
171151
end

0 commit comments

Comments
 (0)