Skip to content

Commit 5bb5d0e

Browse files
committed
Write isapprox in terms of isapprox of the factors
1 parent 7559a54 commit 5bb5d0e

File tree

3 files changed

+57
-16
lines changed

3 files changed

+57
-16
lines changed

src/kroneckerarray.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -378,26 +378,29 @@ end
378378

379379
using LinearAlgebra: dot, promote_leaf_eltypes
380380
function Base.isapprox(
381-
a::AbstractKroneckerArray, b::AbstractKroneckerArray;
382-
atol::Real = 0,
381+
a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0,
383382
rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol),
384-
norm::Function = norm
385383
)
386384
a1, a2 = arg1(a), arg2(a)
387385
b1, b2 = arg1(b), arg2(b)
388-
d = if a1 == b1
389-
norm(a1) * norm(a2 - b2)
386+
if a1 == b1
387+
return isapprox(a2, b2; atol = atol / norm(a1), rtol)
390388
elseif a2 == b2
391-
norm(a1 - b1) * norm(b2)
389+
return isapprox(a1, b1; atol = atol / norm(a2), rtol)
392390
else
393-
# This could be defined as `KroneckerArrays.dist_kronecker(a, b)`, but that might have
394-
# numerical precision issues so for now we just error.
395-
error(
396-
"`isapprox` not implemented for KroneckerArrays where both arguments differ. " *
397-
"In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`."
391+
# This could be defined as:
392+
# ```julia
393+
# d = KroneckerArrays.dist_kronecker(a, b)
394+
# iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
395+
# ```
396+
# but that might have numerical precision issues so for now we just error.
397+
throw(
398+
ArgumentError(
399+
"`isapprox` not implemented for KroneckerArrays where both arguments differ. " *
400+
"In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`."
401+
)
398402
)
399403
end
400-
return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
401404
end
402405

403406
function Base.iszero(a::AbstractKroneckerArray)

src/linearalgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function LinearAlgebra.tr(a::AbstractKroneckerArray)
5858
end
5959

6060
using LinearAlgebra: norm
61-
function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Int = 2)
61+
function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2)
6262
return norm(arg1(a), p) * norm(arg2(a), p)
6363
end
6464

test/test_basics.jl

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,47 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
237237
end
238238
end
239239

240-
# KroneckerArrays.dist
240+
# isapprox
241+
242+
rng = StableRNG(123)
243+
a1 = randn(rng, elt, (2, 2))
244+
a = a1 randn(rng, elt, (3, 3))
245+
b = a1 randn(rng, elt, (3, 3))
246+
@test isapprox(a, b; atol = norm(a - b) * (1 + 2eps(real(elt))))
247+
@test !isapprox(a, b; atol = norm(a - b) * (1 - 2eps(real(elt))))
248+
@test isapprox(
249+
a, b;
250+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
251+
)
252+
@test !isapprox(
253+
a, b;
254+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
255+
)
256+
@test isapprox(
257+
a, b; atol = norm(a - b) * (1 + 2eps(real(elt))),
258+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
259+
)
260+
@test isapprox(
261+
a, b; atol = norm(a - b) * (1 + 2eps(real(elt))),
262+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
263+
)
264+
@test isapprox(
265+
a, b; atol = norm(a - b) * (1 - 2eps(real(elt))),
266+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
267+
)
268+
@test !isapprox(
269+
a, b; atol = norm(a - b) * (1 - 2eps(real(elt))),
270+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
271+
)
272+
273+
a = randn(elt, (2, 2)) randn(elt, (3, 3))
274+
b = randn(elt, (2, 2)) randn(elt, (3, 3))
275+
@test_throws ArgumentError isapprox(a, b)
276+
277+
# KroneckerArrays.dist_kronecker
241278
rng = StableRNG(123)
242-
a = randn(rng, 100, 100) randn(rng, 100, 100)
243-
b = (arg1(a) + 1.0e-1 * randn(rng, size(arg1(a)))) (arg2(a) + 1.0e-1 * randn(rng, size(arg2(a))))
279+
a = randn(rng, (100, 100)) randn(rng, (100, 100))
280+
b = (arg1(a) + randn(rng, size(arg1(a))) / 10)
281+
(arg2(a) + randn(rng, size(arg2(a))) / 10)
244282
@test KroneckerArrays.dist_kronecker(a, b) norm(collect(a) - collect(b)) rtol = 1.0e-2
245283
end

0 commit comments

Comments
 (0)