Skip to content

Commit 61e2420

Browse files
committed
Stricter isapprox, update tests
1 parent 451cd32 commit 61e2420

File tree

3 files changed

+88
-56
lines changed

3 files changed

+88
-56
lines changed

src/kroneckerarray.jl

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,25 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
357357
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
358358
end
359359

360+
# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2)
361+
# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2))
362+
function dist(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
363+
a1, a2 = arg1(a), arg2(a)
364+
b1, b2 = arg1(b), arg2(b)
365+
diff1 = a1 - b1
366+
diff2 = a2 - b2
367+
# x = (a1 - b1) ⊗ a2
368+
# y = b1 ⊗ (a2 - b2)
369+
# z = (a1 - b1) ⊗ (a2 - b2)
370+
xx = norm(diff1)^2 * norm(a2)^2
371+
yy = norm(b1)^2 * norm(diff2)^2
372+
zz = norm(diff1)^2 * norm(diff2)^2
373+
xy = real(dot(diff1, b1) * dot(a2, diff2))
374+
xz = real(dot(diff1, diff1) * dot(a2, diff2))
375+
yz = real(dot(b1, diff1) * dot(diff2, diff2))
376+
return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz)))
377+
end
378+
360379
using LinearAlgebra: dot, promote_leaf_eltypes
361380
function Base.isapprox(
362381
a::AbstractKroneckerArray, b::AbstractKroneckerArray;
@@ -366,12 +385,18 @@ function Base.isapprox(
366385
)
367386
a1, a2 = arg1(a), arg2(a)
368387
b1, b2 = arg1(b), arg2(b)
369-
# Approximation of:
370-
# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2)
371-
# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2))
372-
diff1 = a1 - b1
373-
diff2 = a2 - b2
374-
d = sqrt(norm(diff1)^2 * norm(a2)^2 + norm(b1)^2 * norm(diff2)^2 + 2 * real(dot(diff1, b1) * dot(b2, diff2)))
388+
d = if a1 == b1
389+
norm(b1) * norm(a2 - b2)
390+
elseif a2 == b2
391+
norm(a1 - b1) * norm(b2)
392+
else
393+
# This could be defined as `KroneckerArrays.dist(a, b)`, but that might have
394+
# numerical precision issues so for now we just error.
395+
error(
396+
"`isapprox` not implemented for KroneckerArrays where both arguments differ. " *
397+
"In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`."
398+
)
399+
end
375400
return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
376401
end
377402

test/test_basics.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,9 @@ using DerivableInterfaces: zero!
44
using DiagonalArrays: diagonal
55
using GPUArraysCore: @allowscalar
66
using JLArrays: JLArray
7-
using KroneckerArrays:
8-
KroneckerArrays,
9-
KroneckerArray,
10-
KroneckerStyle,
11-
CartesianProductUnitRange,
12-
CartesianProductVector,
13-
,
14-
×,
15-
arg1,
16-
arg2,
17-
cartesianproduct,
18-
cartesianrange,
19-
kron_nd,
20-
unproduct
7+
using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle,
8+
CartesianProductUnitRange, CartesianProductVector, , ×, arg1, arg2, cartesianproduct,
9+
cartesianrange, kron_nd, unproduct
2110
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr
2211
using StableRNGs: StableRNG
2312
using Test: @test, @test_broken, @test_throws, @testset
@@ -219,10 +208,11 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
219208

220209
a = randn(elt, 2, 2) randn(elt, 3, 3)
221210
b = randn(elt, 2, 2) randn(elt, 3, 3)
222-
c = a.arg1 b.arg2
211+
c = arg1(a) arg2(b)
223212
U, S, V = svd(a)
224213
@test collect(U * diagonal(S) * V') collect(a)
225-
@test svdvals(a) S
214+
@test arg1(svdvals(a)) arg1(S)
215+
@test arg2(svdvals(a)) arg2(S)
226216
@test sort(collect(S); rev = true) svdvals(collect(a))
227217
@test collect(U'U) I
228218
@test collect(V * V') I
@@ -246,4 +236,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
246236
@test_throws ArgumentError $f($a)
247237
end
248238
end
239+
240+
# KroneckerArrays.dist
241+
rng = StableRNG(123)
242+
a = randn(rng, 100, 100) randn(rng, 100, 100)
243+
b = (arg1(a) + 1.0e-1 * randn(rng, size(arg1(a)))) (arg2(a) + 1.0e-1 * randn(rng, size(arg2(a))))
244+
@test KroneckerArrays.dist(a, b) norm(collect(a) - collect(b)) rtol = 1.0e-2
249245
end

test/test_matrixalgebrakit.jl

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,8 @@
1-
using KroneckerArrays: , arguments
1+
using KroneckerArrays: , arg1, arg2
22
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
3-
using MatrixAlgebraKit:
4-
eig_full,
5-
eig_trunc,
6-
eig_vals,
7-
eigh_full,
8-
eigh_trunc,
9-
eigh_vals,
10-
left_null,
11-
left_orth,
12-
left_polar,
13-
lq_compact,
14-
lq_full,
15-
qr_compact,
16-
qr_full,
17-
right_null,
18-
right_orth,
19-
right_polar,
20-
svd_compact,
21-
svd_full,
22-
svd_trunc,
3+
using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc,
4+
eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact,
5+
qr_full, right_null, right_orth, right_polar, svd_compact, svd_full, svd_trunc,
236
svd_vals
247
using Test: @test, @test_throws, @testset
258
using TestExtras: @constinferred
@@ -31,18 +14,26 @@ herm(a) = parent(hermitianpart(a))
3114

3215
a = randn(elt, 2, 2) randn(elt, 3, 3)
3316
d, v = eig_full(a)
34-
@test a * v v * d
17+
av = a * v
18+
vd = v * d
19+
@test arg1(av) arg1(vd)
20+
@test arg2(av) arg2(vd)
3521

3622
a = randn(elt, 2, 2) randn(elt, 3, 3)
3723
@test_throws ArgumentError eig_trunc(a)
3824

3925
a = randn(elt, 2, 2) randn(elt, 3, 3)
4026
d = eig_vals(a)
41-
@test d diag(eig_full(a)[1])
27+
d′ = diag(eig_full(a)[1])
28+
@test arg1(d) arg1(d′)
29+
@test arg2(d) arg2(d′)
4230

4331
a = herm(randn(elt, 2, 2)) herm(randn(elt, 3, 3))
4432
d, v = eigh_full(a)
45-
@test a * v v * d
33+
av = a * v
34+
vd = v * d
35+
@test arg1(av) arg1(vd)
36+
@test arg2(av) arg2(vd)
4637
@test eltype(d) === real(elt)
4738
@test eltype(v) === elt
4839

@@ -56,22 +47,30 @@ herm(a) = parent(hermitianpart(a))
5647

5748
a = randn(elt, 2, 2) randn(elt, 3, 3)
5849
u, c = qr_compact(a)
59-
@test u * c a
50+
uc = u * c
51+
@test arg1(uc) arg1(a)
52+
@test arg2(uc) arg2(a)
6053
@test collect(u'u) I
6154

6255
a = randn(elt, 2, 2) randn(elt, 3, 3)
6356
u, c = qr_full(a)
64-
@test u * c a
57+
uc = u * c
58+
@test arg1(uc) arg1(a)
59+
@test arg2(uc) arg2(a)
6560
@test collect(u'u) I
6661

6762
a = randn(elt, 2, 2) randn(elt, 3, 3)
6863
c, u = lq_compact(a)
69-
@test c * u a
64+
cu = c * u
65+
@test arg1(cu) arg1(a)
66+
@test arg2(cu) arg2(a)
7067
@test collect(u * u') I
7168

7269
a = randn(elt, 2, 2) randn(elt, 3, 3)
7370
c, u = lq_full(a)
74-
@test c * u a
71+
cu = c * u
72+
@test arg1(cu) arg1(a)
73+
@test arg2(cu) arg2(a)
7574
@test collect(u * u') I
7675

7776
a = randn(elt, 3, 2) randn(elt, 4, 3)
@@ -84,27 +83,37 @@ herm(a) = parent(hermitianpart(a))
8483

8584
a = randn(elt, 2, 2) randn(elt, 3, 3)
8685
u, c = left_orth(a)
87-
@test u * c a
86+
uc = u * c
87+
@test arg1(uc) arg1(a)
88+
@test arg2(uc) arg2(a)
8889
@test collect(u'u) I
8990

9091
a = randn(elt, 2, 2) randn(elt, 3, 3)
9192
c, u = right_orth(a)
92-
@test c * u a
93+
cu = c * u
94+
@test arg1(cu) arg1(a)
95+
@test arg2(cu) arg2(a)
9396
@test collect(u * u') I
9497

9598
a = randn(elt, 2, 2) randn(elt, 3, 3)
9699
u, c = left_polar(a)
97-
@test u * c a
100+
uc = u * c
101+
@test arg1(uc) arg1(a)
102+
@test arg2(uc) arg2(a)
98103
@test collect(u'u) I
99104

100105
a = randn(elt, 2, 2) randn(elt, 3, 3)
101106
c, u = right_polar(a)
102-
@test c * u a
107+
cu = c * u
108+
@test arg1(cu) arg1(a)
109+
@test arg2(cu) arg2(a)
103110
@test collect(u * u') I
104111

105112
a = randn(elt, 2, 2) randn(elt, 3, 3)
106113
u, s, v = svd_compact(a)
107-
@test u * s * v a
114+
usv = u * s * v
115+
@test arg1(usv) arg1(a)
116+
@test arg2(usv) arg2(a)
108117
@test eltype(u) === elt
109118
@test eltype(s) === real(elt)
110119
@test eltype(v) === elt
@@ -113,7 +122,9 @@ herm(a) = parent(hermitianpart(a))
113122

114123
a = randn(elt, 2, 2) randn(elt, 3, 3)
115124
u, s, v = svd_full(a)
116-
@test u * s * v a
125+
usv = u * s * v
126+
@test arg1(usv) arg1(a)
127+
@test arg2(usv) arg2(a)
117128
@test eltype(u) === elt
118129
@test eltype(s) === real(elt)
119130
@test eltype(v) === elt

0 commit comments

Comments
 (0)