Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
46 changes: 42 additions & 4 deletions src/kroneckerarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,49 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
end

# TODO: this definition doesn't fully retain the original meaning:
# ‖a - b‖ < atol could be true even if the following check isn't
function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...)
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2)
# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2))
function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
a1, a2 = arg1(a), arg2(a)
b1, b2 = arg1(b), arg2(b)
diff1 = a1 - b1
diff2 = a2 - b2
# x = (a1 - b1) ⊗ a2
# y = b1 ⊗ (a2 - b2)
# z = (a1 - b1) ⊗ (a2 - b2)
xx = norm(diff1)^2 * norm(a2)^2
yy = norm(b1)^2 * norm(diff2)^2
zz = norm(diff1)^2 * norm(diff2)^2
xy = real(dot(diff1, b1) * dot(a2, diff2))
xz = real(dot(diff1, diff1) * dot(a2, diff2))
yz = real(dot(b1, diff1) * dot(diff2, diff2))
return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz)))
end

using LinearAlgebra: dot, promote_leaf_eltypes
function Base.isapprox(
a::AbstractKroneckerArray, b::AbstractKroneckerArray;
atol::Real = 0,
rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol),
norm::Function = norm
)
a1, a2 = arg1(a), arg2(a)
b1, b2 = arg1(b), arg2(b)
d = if a1 == b1
norm(a1) * norm(a2 - b2)
elseif a2 == b2
norm(a1 - b1) * norm(b2)
else
# This could be defined as `KroneckerArrays.dist_kronecker(a, b)`, but that might have
# numerical precision issues so for now we just error.
error(
"`isapprox` not implemented for KroneckerArrays where both arguments differ. " *
"In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`."
)
end
return iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
end

function Base.iszero(a::AbstractKroneckerArray)
return iszero(arg1(a)) || iszero(arg2(a))
end
Expand Down
28 changes: 12 additions & 16 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,9 @@ using DerivableInterfaces: zero!
using DiagonalArrays: diagonal
using GPUArraysCore: @allowscalar
using JLArrays: JLArray
using KroneckerArrays:
KroneckerArrays,
KroneckerArray,
KroneckerStyle,
CartesianProductUnitRange,
CartesianProductVector,
⊗,
×,
arg1,
arg2,
cartesianproduct,
cartesianrange,
kron_nd,
unproduct
using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle,
CartesianProductUnitRange, CartesianProductVector, ⊗, ×, arg1, arg2, cartesianproduct,
cartesianrange, kron_nd, unproduct
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr
using StableRNGs: StableRNG
using Test: @test, @test_broken, @test_throws, @testset
Expand Down Expand Up @@ -219,10 +208,11 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c = a.arg1 ⊗ b.arg2
c = arg1(a) ⊗ arg2(b)
U, S, V = svd(a)
@test collect(U * diagonal(S) * V') ≈ collect(a)
@test svdvals(a) ≈ S
@test arg1(svdvals(a)) ≈ arg1(S)
@test arg2(svdvals(a)) ≈ arg2(S)
@test sort(collect(S); rev = true) ≈ svdvals(collect(a))
@test collect(U'U) ≈ I
@test collect(V * V') ≈ I
Expand All @@ -246,4 +236,10 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
@test_throws ArgumentError $f($a)
end
end

# KroneckerArrays.dist
rng = StableRNG(123)
a = randn(rng, 100, 100) ⊗ randn(rng, 100, 100)
b = (arg1(a) + 1.0e-1 * randn(rng, size(arg1(a)))) ⊗ (arg2(a) + 1.0e-1 * randn(rng, size(arg2(a))))
@test KroneckerArrays.dist_kronecker(a, b) ≈ norm(collect(a) - collect(b)) rtol = 1.0e-2
end
83 changes: 48 additions & 35 deletions test/test_matrixalgebrakit.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,8 @@
using KroneckerArrays: ⊗, arguments
using KroneckerArrays: ⊗, arg1, arg2
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
using MatrixAlgebraKit:
eig_full,
eig_trunc,
eig_vals,
eigh_full,
eigh_trunc,
eigh_vals,
left_null,
left_orth,
left_polar,
lq_compact,
lq_full,
qr_compact,
qr_full,
right_null,
right_orth,
right_polar,
svd_compact,
svd_full,
svd_trunc,
using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc,
eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact,
qr_full, right_null, right_orth, right_polar, svd_compact, svd_full, svd_trunc,
svd_vals
using Test: @test, @test_throws, @testset
using TestExtras: @constinferred
Expand All @@ -31,18 +14,26 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
d, v = eig_full(a)
@test a * v ≈ v * d
av = a * v
vd = v * d
@test arg1(av) ≈ arg1(vd)
@test arg2(av) ≈ arg2(vd)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
@test_throws ArgumentError eig_trunc(a)

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
d = eig_vals(a)
@test d ≈ diag(eig_full(a)[1])
d′ = diag(eig_full(a)[1])
@test arg1(d) ≈ arg1(d′)
@test arg2(d) ≈ arg2(d′)

a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3))
d, v = eigh_full(a)
@test a * v ≈ v * d
av = a * v
vd = v * d
@test arg1(av) ≈ arg1(vd)
@test arg2(av) ≈ arg2(vd)
@test eltype(d) === real(elt)
@test eltype(v) === elt

Expand All @@ -56,22 +47,30 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = qr_compact(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = qr_full(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = lq_compact(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = lq_full(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3)
Expand All @@ -84,27 +83,37 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = left_orth(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = right_orth(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, c = left_polar(a)
@test u * c ≈ a
uc = u * c
@test arg1(uc) ≈ arg1(a)
@test arg2(uc) ≈ arg2(a)
@test collect(u'u) ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
c, u = right_polar(a)
@test c * u ≈ a
cu = c * u
@test arg1(cu) ≈ arg1(a)
@test arg2(cu) ≈ arg2(a)
@test collect(u * u') ≈ I

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, s, v = svd_compact(a)
@test u * s * v ≈ a
usv = u * s * v
@test arg1(usv) ≈ arg1(a)
@test arg2(usv) ≈ arg2(a)
@test eltype(u) === elt
@test eltype(s) === real(elt)
@test eltype(v) === elt
Expand All @@ -113,7 +122,9 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
u, s, v = svd_full(a)
@test u * s * v ≈ a
usv = u * s * v
@test arg1(usv) ≈ arg1(a)
@test arg2(usv) ≈ arg2(a)
@test eltype(u) === elt
@test eltype(s) === real(elt)
@test eltype(v) === elt
Expand All @@ -125,5 +136,7 @@ herm(a) = parent(hermitianpart(a))

a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3)
s = svd_vals(a)
@test s ≈ diag(svd_compact(a)[2])
s′ = diag(svd_compact(a)[2])
@test arg1(s) ≈ arg1(s′)
@test arg2(s) ≈ arg2(s′)
end
Loading