Skip to content

Commit f88d9f1

Browse files
committed
WIP: norm.jl autotangent
1 parent dbc8134 commit f88d9f1

File tree

1 file changed

+21
-27
lines changed

1 file changed

+21
-27
lines changed

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,16 @@
2323
kwargs = NamedTuple()
2424
end
2525

26-
y = fnorm(x)
27-
= rand_tangent(x)
28-
= rand_tangent(y)
29-
3026
fnorm === LinearAlgebra.norm2 && @testset "frule" begin
31-
= rand_tangent(x)
32-
frule_test(fnorm, (x, ẋ))
27+
test_frule(fnorm, x)
3328
end
3429
@testset "rrule" begin
35-
rrule_test(fnorm, ȳ, (x, x̄); kwargs...)
30+
test_rrule(fnorm, x; kwargs...)
3631
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
37-
rrule_test(fnorm, ȳ, (MT(x), MT(x̄)); kwargs...)
32+
test_rrule(fnorm, MT(x); kwargs...)
3833
end
34+
35+
= rand_tangent(fnorm(x))
3936
@test extern(rrule(fnorm, zero(x))[2](ȳ)[2]) zero(x)
4037
@test rrule(fnorm, x)[2](Zero())[2] isa Zero
4138
end
@@ -45,23 +42,23 @@
4542
sz in [(0,), (3,), (3, 3), (3, 2, 1)]
4643

4744
x = randn(T, sz)
48-
y = norm(x)
49-
= rand_tangent(x)
50-
= rand_tangent(x)
51-
= rand_tangent(y)
5245

5346
@testset "frule" begin
54-
frule_test(norm, (x, ẋ))
47+
test_frule(norm, x)
5548
@test frule((Zero(), Zero()), norm, x)[2] isa Zero
49+
50+
= rand_tangent(x)
5651
@test iszero(frule((Zero(), ẋ), norm, zero(x))[2])
5752
end
5853
@testset "rrule" begin
59-
rrule_test(norm, ȳ, (x, x̄))
54+
test_rrule(norm, x)
6055
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
6156
# we don't check inference on older julia versions. Improvements to
6257
# inference mean on 1.5+ it works, and that is good enough
63-
rrule_test(norm, ȳ, (MT(x), MT(x̄)); check_inferred=VERSION>=v"1.5")
58+
test_rrule(norm, MT(x); check_inferred=VERSION>=v"1.5")
6459
end
60+
61+
= rand_tangent(norm(x))
6562
@test extern(rrule(norm, zero(x))[2](ȳ)[2]) zero(x)
6663
@test rrule(norm, x)[2](Zero())[2] isa Zero
6764
end
@@ -90,19 +87,16 @@
9087
kwargs = NamedTuple()
9188
end
9289

93-
y = fnorm(x, p)
94-
= rand_tangent(x)
95-
= rand_tangent(y)
96-
= rand_tangent(p)
9790

98-
rrule_test(fnorm, ȳ, (x, x̄), (p, p̄); kwargs...)
91+
test_rrule(fnorm, x, p; kwargs...)
9992
x isa Matrix && @testset "$MT" for MT in (Diagonal, UpperTriangular, LowerTriangular)
100-
rrule_test(
101-
fnorm, ȳ, (MT(x), MT(x̄)), (p, p̄);
93+
test_rrule(fnorm, MT(x), p;
10294
#Don't check inference on old julia, what matters is that works on new
10395
check_inferred=VERSION>=v"1.5", kwargs...
10496
)
10597
end
98+
99+
= rand_tangent(fnorm(x))
106100
@test extern(rrule(fnorm, zero(x), p)[2](ȳ)[2]) zero(x)
107101
@test rrule(fnorm, x, p)[2](Zero())[2] isa Zero
108102
end
@@ -116,7 +110,7 @@
116110
= rand_tangent(x)
117111
= rand_tangent(y)
118112
= rand_tangent(p)
119-
rrule_test(norm, ȳ, (x, x̄), (p, p̄))
113+
test_rrule(norm, ȳ, (x, x̄), (p, p̄))
120114
@test extern(rrule(norm, x, p)[2](ȳ)[2]) isa typeof(x)
121115
end
122116
@testset "norm(x::$T, p)" for T in (Float64, ComplexF64)
@@ -125,8 +119,8 @@
125119
y = norm(x, p)
126120
ẋ, ṗ = rand_tangent.((x, p))
127121
x̄, p̄, ȳ = rand_tangent.((x, p, y))
128-
frule_test(norm, (x, ẋ), (p, ṗ))
129-
rrule_test(norm, ȳ, (x, x̄), (p, p̄))
122+
test_frule(norm, (x, ẋ), (p, ṗ))
123+
test_rrule(norm, ȳ, (x, x̄), (p, p̄))
130124
_, back = rrule(norm, x, p)
131125
@test back(Zero()) == (NO_FIELDS, Zero(), Zero())
132126
end
@@ -154,7 +148,7 @@ end
154148
y = normalize(x)
155149
= rand_tangent(x)
156150
= rand_tangent(y)
157-
rrule_test(normalize, ȳ, (x, x̄))
151+
test_rrule(normalize, ȳ, (x, x̄))
158152
@test rrule(normalize, x)[2](Zero()) === (NO_FIELDS, Zero())
159153
end
160154
@testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64),
@@ -165,7 +159,7 @@ end
165159
= rand_tangent(x)
166160
= rand_tangent(y)
167161
= rand_tangent(p)
168-
rrule_test(normalize, ȳ, (x, x̄), (p, p̄))
162+
test_rrule(normalize, ȳ, (x, x̄), (p, p̄))
169163
@test rrule(normalize, x, p)[2](Zero()) === (NO_FIELDS, Zero(), Zero())
170164
end
171165
end

0 commit comments

Comments
 (0)