Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KroneckerArrays"
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.4"
version = "0.1.5"

[deps]
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
Expand Down
184 changes: 168 additions & 16 deletions src/KroneckerArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -582,9 +582,11 @@ using MatrixAlgebraKit:
eigh_full,
qr_compact,
qr_full,
left_orth,
left_polar,
lq_compact,
lq_full,
right_orth,
right_polar,
svd_compact,
svd_full
Expand All @@ -608,12 +610,22 @@ for f in [
end
end

for f in (:eig, :eigh, :lq, :qr, :polar, :svd)
ff = Symbol("default_", f, "_algorithm")
for f in [
:default_eig_algorithm,
:default_eigh_algorithm,
:default_lq_algorithm,
:default_qr_algorithm,
:default_polar_algorithm,
:default_svd_algorithm,
]
@eval begin
function MatrixAlgebraKit.$ff(A::Type{<:KroneckerMatrix}; kwargs...)
function MatrixAlgebraKit.$f(
A::Type{<:KroneckerMatrix}; kwargs1=(;), kwargs2=(;), kwargs...
)
A1, A2 = argument_types(A)
return KroneckerAlgorithm($ff(A1; kwargs...), $ff(A2; kwargs...))
return KroneckerAlgorithm(
$f(A1; kwargs..., kwargs1...), $f(A2; kwargs..., kwargs2...)
)
end
end
end
Expand All @@ -631,7 +643,7 @@ function MatrixAlgebraKit.default_algorithm(
return default_qr_algorithm(A; kwargs...)
end

for f in (
for f in [
:eig_full!,
:eigh_full!,
:qr_compact!,
Expand All @@ -642,22 +654,24 @@ for f in (
:right_polar!,
:svd_compact!,
:svd_full!,
)
]
@eval begin
function MatrixAlgebraKit.initialize_output(
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
)
return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b)
end
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...)
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...)
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...)
function MatrixAlgebraKit.$f(
a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
)
$f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
$f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
return F
end
end
end

for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
for f in [:eig_vals!, :eigh_vals!, :svd_vals!]
@eval begin
function MatrixAlgebraKit.initialize_output(
::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm
Expand All @@ -672,7 +686,7 @@ for f in (:eig_vals!, :eigh_vals!, :svd_vals!)
end
end

for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!]
@eval begin
function MatrixAlgebraKit.truncate!(
::typeof($f),
Expand All @@ -684,25 +698,163 @@ for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!)
end
end

for f in (:left_orth!, :right_orth!)
for f in [:left_orth!, :right_orth!]
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
return initialize_output($f, a.a) .⊗ initialize_output($f, a.b)
end
end
end

for f in (:left_null!, :right_null!)
for f in [:left_null!, :right_null!]
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix)
return initialize_output($f, a.a) ⊗ initialize_output($f, a.b)
end
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...)
$f(a.a, F.a; kwargs...)
$f(a.b, F.b; kwargs...)
function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs1=(;), kwargs2=(;), kwargs...)
$f(a.a, F.a; kwargs..., kwargs1...)
$f(a.b, F.b; kwargs..., kwargs2...)
return F
end
end
end

####################################################################################
# Special cases for MatrixAlgebraKit factorizations of `Eye(n) ⊗ A` and
# `A ⊗ Eye(n)` where `A`.
# TODO: Delete this once https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/34
# is merged.

using FillArrays: SquareEye
const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}
const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B}

struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm
kwargs::KWargs
end
SquareEyeAlgorithm(; kwargs...) = SquareEyeAlgorithm((; kwargs...))

# Defined to avoid type piracy.
_copy_input_squareeye(f::F, a) where {F} = copy_input(f, a)
_copy_input_squareeye(f::F, a::SquareEye) where {F} = a

for f in [
:eig_full,
:eigh_full,
:qr_compact,
:qr_full,
:left_orth,
:left_polar,
:lq_compact,
:lq_full,
:right_orth,
:right_polar,
:svd_compact,
:svd_full,
]
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
@eval begin
function MatrixAlgebraKit.copy_input(::typeof($f), a::$T)
return _copy_input_squareeye($f, a.a) ⊗ _copy_input_squareeye($f, a.b)
end
end
end
end

for f in [
:default_eig_algorithm,
:default_eigh_algorithm,
:default_lq_algorithm,
:default_qr_algorithm,
:default_polar_algorithm,
:default_svd_algorithm,
]
f′ = Symbol("_", f, "_squareeye")
@eval begin
$f′(a; kwargs...) = $f(a; kwargs...)
$f′(a::Type{<:SquareEye}; kwargs...) = SquareEyeAlgorithm(; kwargs...)
end
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
@eval begin
function MatrixAlgebraKit.$f(A::Type{<:$T}; kwargs1=(;), kwargs2=(;), kwargs...)
A1, A2 = argument_types(A)
return KroneckerAlgorithm(
$f′(A1; kwargs..., kwargs1...), $f′(A2; kwargs..., kwargs2...)
)
end
end
end
end

# Defined to avoid type piracy.
_initialize_output_squareeye(f::F, a) where {F} = initialize_output(f, a)
_initialize_output_squareeye(f::F, a, alg) where {F} = initialize_output(f, a, alg)

for f in [
:eig_full!,
:eigh_full!,
:qr_compact!,
:qr_full!,
:left_orth!,
:left_polar!,
:lq_compact!,
:lq_full!,
:right_orth!,
:right_polar!,
]
@eval begin
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a)
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a)
end
end
for f in [:svd_compact!, :svd_full!]
@eval begin
_initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a)
_initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a)
end
end

for f in [
:eig_full!,
:eigh_full!,
:qr_compact!,
:qr_full!,
:left_orth!,
:left_polar!,
:lq_compact!,
:lq_full!,
:right_orth!,
:right_polar!,
:svd_compact!,
:svd_full!,
]
f′ = Symbol("_", f, "_squareeye")
@eval begin
$f′(a, F, alg; kwargs...) = $f(a, F, alg; kwargs...)
$f′(a, F, alg::SquareEyeAlgorithm) = F
end
for T in [:SquareEyeKronecker, :KroneckerSquareEye, :SquareEyeSquareEye]
@eval begin
function MatrixAlgebraKit.initialize_output(::typeof($f), a::$T)
return _initialize_output_squareeye($f, a.a) .⊗
_initialize_output_squareeye($f, a.b)
end
function MatrixAlgebraKit.initialize_output(
::typeof($f), a::$T, alg::KroneckerAlgorithm
)
return _initialize_output_squareeye($f, a.a, alg.a) .⊗
_initialize_output_squareeye($f, a.b, alg.b)
end
function MatrixAlgebraKit.$f(
a::$T, F, alg::KroneckerAlgorithm; kwargs1=(;), kwargs2=(;), kwargs...
)
$f′(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs..., kwargs1...)
$f′(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs..., kwargs2...)
return F
end
end
end
end

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8"
Expand All @@ -17,3 +18,4 @@ MatrixAlgebraKit = "0.2"
SafeTestsets = "0.1"
Suppressor = "0.2"
Test = "1.10"
TestExtras = "0.3"
91 changes: 89 additions & 2 deletions test/test_matrixalgebrakit.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using KroneckerArrays: ⊗
using FillArrays: Eye
using KroneckerArrays: ⊗, arguments
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
using MatrixAlgebraKit:
eig_full,
Expand All @@ -22,8 +23,9 @@ using MatrixAlgebraKit:
svd_trunc,
svd_vals
using Test: @test, @test_throws, @testset
using TestExtras: @constinferred

herm(a) = hermitianpart(a).data
herm(a) = parent(hermitianpart(a))

@testset "MatrixAlgebraKit" begin
elt = Float32
Expand Down Expand Up @@ -117,3 +119,88 @@ herm(a) = hermitianpart(a).data
s = svd_vals(a)
@test s ≈ diag(svd_compact(a)[2])
end

@testset "MatrixAlgebraKit + Eye" begin

# TODO:
# eig_trunc
# eig_vals
# eigh_trunc
# eigh_vals
# left_null
# right_null
# svd_trunc
# svd_vals

for f in (eig_full, eigh_full)
a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3)))
d, v = @constinferred f(a)
@test a * v ≈ v * d
@test arguments(d, 1) isa Eye
@test arguments(v, 1) isa Eye

a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3)
d, v = @constinferred f(a)
@test a * v ≈ v * d
@test arguments(d, 2) isa Eye
@test arguments(v, 2) isa Eye

a = Eye(3) ⊗ Eye(3)
d, v = @constinferred f(a)
@test a * v ≈ v * d
@test arguments(d, 1) isa Eye
@test arguments(d, 2) isa Eye
@test arguments(v, 1) isa Eye
@test arguments(v, 2) isa Eye
end

for f in (
left_orth, left_polar, lq_compact, lq_full, qr_compact, qr_full, right_orth, right_polar
)
a = Eye(3) ⊗ randn(3, 3)
x, y = f(a)
@test x * y ≈ a
@test arguments(x, 1) isa Eye
@test arguments(y, 1) isa Eye

a = randn(3, 3) ⊗ Eye(3)
x, y = f(a)
@test x * y ≈ a
@test arguments(x, 2) isa Eye
@test arguments(y, 2) isa Eye

a = Eye(3) ⊗ Eye(3)
x, y = f(a)
@test x * y ≈ a
@test arguments(x, 1) isa Eye
@test arguments(y, 1) isa Eye
@test arguments(x, 2) isa Eye
@test arguments(y, 2) isa Eye
end

for f in (svd_compact, svd_full)
a = Eye(3) ⊗ randn(3, 3)
u, s, v = f(a)
@test u * s * v ≈ a
@test arguments(u, 1) isa Eye
@test arguments(s, 1) isa Eye
@test arguments(v, 1) isa Eye

a = randn(3, 3) ⊗ Eye(3)
u, s, v = f(a)
@test u * s * v ≈ a
@test arguments(u, 2) isa Eye
@test arguments(s, 2) isa Eye
@test arguments(v, 2) isa Eye

a = Eye(3) ⊗ Eye(3)
u, s, v = f(a)
@test u * s * v ≈ a
@test arguments(u, 1) isa Eye
@test arguments(s, 1) isa Eye
@test arguments(v, 1) isa Eye
@test arguments(u, 2) isa Eye
@test arguments(s, 2) isa Eye
@test arguments(v, 2) isa Eye
end
end
Loading