From b5c121a7e87e4a36cc41c03dc7986050eb86f5f0 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 28 May 2025 18:34:38 -0400 Subject: [PATCH 1/3] More functionality --- Project.toml | 8 +- src/KroneckerArrays.jl | 204 +++++++++++++++++++++++++++++++++- test/Project.toml | 1 + test/test_matrixalgebrakit.jl | 47 ++++++++ 4 files changed, 253 insertions(+), 7 deletions(-) create mode 100644 test/test_matrixalgebrakit.jl diff --git a/Project.toml b/Project.toml index c612160..8e8934f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,17 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] +DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [compat] +DerivableInterfaces = "0.4.5" +GPUArraysCore = "0.2.0" LinearAlgebra = "1.10" +MatrixAlgebraKit = "0.2.0" julia = "1.10" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index e91790a..3a5a074 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -1,5 +1,7 @@ module KroneckerArrays +using GPUArraysCore: GPUArraysCore + export ⊗, × struct CartesianProduct{A,B} @@ -28,6 +30,26 @@ end Base.first(r::CartesianProductUnitRange) = first(r.range) Base.last(r::CartesianProductUnitRange) = last(r.range) +function Base.axes(r::CartesianProductUnitRange) + return (CartesianProductUnitRange(r.product, only(axes(r.range))),) +end + +using Base.Broadcast: DefaultArrayStyle +for f in (:+, :-) + @eval begin + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer + ) + return CartesianProductUnitRange(r.product, $f.(r.range, x)) + end + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange + ) + return CartesianProductUnitRange(r.product, $f.(x, r.range)) + end + end +end + struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N} a::A b::B @@ -44,6 +66,15 @@ end const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B} const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B} +function Base.copy(a::KroneckerArray) + return copy(a.a) ⊗ copy(a.b) +end +function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) + copyto!(dest.a, src.a) + copyto!(dest.b, src.b) + return dest +end + function Base.similar( a::AbstractArray, elt::Type, @@ -73,9 +104,21 @@ function Base.similar( return similar(arrayt, map(ax -> ax.product.a, axs)) ⊗ similar(arrayt, map(ax -> ax.product.b, axs)) end +function Base.similar( + arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) where {A,B} + return similar(A, map(ax -> ax.product.a, axs)) ⊗ similar(B, map(ax -> ax.product.b, axs)) +end Base.collect(a::KroneckerArray) = kron(a.a, a.b) +function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} + return convert(Array{T,N}, collect(a)) +end + Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a)) function Base.axes(a::KroneckerArray) @@ -107,12 +150,23 @@ end ⊗(a::Number, b::AbstractVecOrMat) = a * b ⊗(a::AbstractVecOrMat, b::Number) = a * b -function Base.getindex(::KroneckerArray, ::Int) - return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported.")) +function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) + GPUArraysCore.assertscalar("getindex") + # Code logic from Kronecker.jl: + # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 + k, l = size(a.b) + return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] end -function Base.getindex(::KroneckerArray{<:Any,N}, ::Vararg{Int,N}) where {N} - return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported.")) +function Base.getindex(a::KroneckerMatrix, i::Integer) + return a[CartesianIndices(a)[i]] end + +function Base.getindex(a::KroneckerVector, i::Integer) + GPUArraysCore.assertscalar("getindex") + k = length(a.b) + return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] +end + function Base.getindex(a::KroneckerVector, i::CartesianProduct) return a.a[i.a] ⊗ a.b[i.b] end @@ -169,9 +223,18 @@ end function Base.:*(a::KroneckerArray, b::KroneckerArray) return (a.a * b.a) ⊗ (a.b * b.b) end -function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray) +function LinearAlgebra.mul!( + c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number +) + iszero(β) || + iszero(c) || + throw( + ArgumentError( + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + ), + ) mul!(c.a, a.a, b.a) - mul!(c.b, a.b, b.b) + mul!(c.b, a.b, b.b, α, β) return c end function LinearAlgebra.tr(a::KroneckerArray) @@ -269,4 +332,133 @@ for op in (:+, :-) end end +function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) + dest.a .= a.a + dest.b .= a.b + return dest +end +function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray) + if a.b == b.b + map!(+, dest.a, a.a, b.a) + dest.b .= a.b + elseif a.a == b.a + dest.a .= a.a + map!(+, dest.b, a.b, b.b) + else + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + return dest +end +function Base.map!( + f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray +) + dest.a .= f.x .* a.a + dest.b .= a.b + return dest +end +function Base.map!( + f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray +) + dest.a .= a.a + dest.b .= a.b .* f.x + return dest +end + +using DerivableInterfaces: DerivableInterfaces, zero! +function DerivableInterfaces.zero!(a::KroneckerArray) + zero!(a.a) + zero!(a.b) + return a +end + +using MatrixAlgebraKit: + MatrixAlgebraKit, + AbstractAlgorithm, + TruncationStrategy, + default_eig_algorithm, + default_eigh_algorithm, + eig_full!, + eig_trunc!, + eig_vals!, + eigh_full!, + eigh_trunc!, + eigh_vals!, + initialize_output, + truncate! + +struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm + a::A + b::B +end + +function MatrixAlgebraKit.default_eig_algorithm(a::KroneckerMatrix) + return KroneckerAlgorithm(default_eig_algorithm(a.a), default_eig_algorithm(a.b)) +end +function MatrixAlgebraKit.initialize_output( + f::typeof(eig_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm +) + return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b) +end +function MatrixAlgebraKit.eig_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + eig_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a) + eig_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b) + return F +end + +function MatrixAlgebraKit.truncate!( + ::typeof(eig_trunc!), + (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, + strategy::TruncationStrategy, +) + return throw(MethodError(truncate!, (eig_trunc!, (D, V), strategy))) +end + +function MatrixAlgebraKit.initialize_output( + f::typeof(eig_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm +) + return initialize_output(f, a.a, alg.a) ⊗ initialize_output(f, a.b, alg.b) +end +function MatrixAlgebraKit.eig_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + eig_vals!(a.a, F.a, alg.a) + eig_vals!(a.b, F.b, alg.b) + return F +end + +function MatrixAlgebraKit.default_eigh_algorithm(a::KroneckerMatrix) + return KroneckerAlgorithm(default_eigh_algorithm(a.a), default_eigh_algorithm(a.b)) +end +function MatrixAlgebraKit.initialize_output( + f::typeof(eigh_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm +) + return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b) +end +function MatrixAlgebraKit.eigh_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + eigh_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a) + eigh_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b) + return F +end + +function MatrixAlgebraKit.truncate!( + ::typeof(eigh_trunc!), + (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, + strategy::TruncationStrategy, +) + return throw(MethodError(truncate!, (eigh_trunc!, (D, V), strategy))) +end + +function MatrixAlgebraKit.initialize_output( + f::typeof(eigh_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm +) + return initialize_output(f, a.a, alg.a) ⊗ initialize_output(f, a.b, alg.b) +end +function MatrixAlgebraKit.eigh_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + eigh_vals!(a.a, F.a, alg.a) + eigh_vals!(a.b, F.b, alg.b) + return F +end + end diff --git a/test/Project.toml b/test/Project.toml index 3b27675..772eff2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl new file mode 100644 index 0000000..0f9f303 --- /dev/null +++ b/test/test_matrixalgebrakit.jl @@ -0,0 +1,47 @@ +using KroneckerArrays: ⊗ +using LinearAlgebra: Hermitian, diag +using MatrixAlgebraKit: + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc, + svd_vals +using Test: @test, @test_throws, @testset + +@testset "MatrixAlgebraKit" begin + x = randn(2, 2) + y = randn(3, 3) + a = x ⊗ y + ah = Hermitian(x) ⊗ Hermitian(y) + + d, v = eig_full(a) + @test a * v ≈ v * d + + @test_throws MethodError eig_trunc(a) + + d = eig_vals(a) + @test d ≈ diag(eig_full(a)[1]) + + d, v = eigh_full(ah) + @test ah * v ≈ v * d + + @test_throws MethodError eigh_trunc(ah) + + d = eigh_vals(ah) + @test d ≈ diag(eigh_full(ah)[1]) +end From 29965411df2075ac5eb3f80a0f00a8480c7d5e42 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 28 May 2025 18:48:45 -0400 Subject: [PATCH 2/3] Implement left and right null --- src/KroneckerArrays.jl | 32 ++++++++++++++++++++++++++++++++ test/test_matrixalgebrakit.jl | 31 +++++++++++++++++++++---------- 2 files changed, 53 insertions(+), 10 deletions(-) diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 3a5a074..aa64a21 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -381,6 +381,7 @@ using MatrixAlgebraKit: TruncationStrategy, default_eig_algorithm, default_eigh_algorithm, + default_qr_algorithm, eig_full!, eig_trunc!, eig_vals!, @@ -388,6 +389,8 @@ using MatrixAlgebraKit: eigh_trunc!, eigh_vals!, initialize_output, + left_null!, + right_null!, truncate! struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm @@ -461,4 +464,33 @@ function MatrixAlgebraKit.eigh_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgori return F end +function MatrixAlgebraKit.default_qr_algorithm(a::KroneckerMatrix; kwargs...) + return KroneckerAlgorithm( + default_qr_algorithm(a.a; kwargs...), default_qr_algorithm(a.b; kwargs...) + ) +end +function MatrixAlgebraKit.default_lq_algorithm(a::KroneckerMatrix; kwargs...) + return KroneckerAlgorithm( + default_lq_algorithm(a.a; kwargs...), default_lq_algorithm(a.b; kwargs...) + ) +end + +function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::KroneckerMatrix) + return initialize_output(f, a.a) ⊗ initialize_output(f, a.b) +end +function MatrixAlgebraKit.left_null!(a::KroneckerMatrix, F; kwargs...) + left_null!(a.a, F.a; kwargs...) + left_null!(a.b, F.b; kwargs...) + return F +end + +function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::KroneckerMatrix) + return initialize_output(f, a.a) ⊗ initialize_output(f, a.b) +end +function MatrixAlgebraKit.right_null!(a::KroneckerMatrix, F; kwargs...) + right_null!(a.a, F.a; kwargs...) + right_null!(a.b, F.b; kwargs...) + return F +end + end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 0f9f303..f5b4baf 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,5 +1,5 @@ using KroneckerArrays: ⊗ -using LinearAlgebra: Hermitian, diag +using LinearAlgebra: Hermitian, diag, norm using MatrixAlgebraKit: eig_full, eig_trunc, @@ -24,24 +24,35 @@ using MatrixAlgebraKit: using Test: @test, @test_throws, @testset @testset "MatrixAlgebraKit" begin - x = randn(2, 2) - y = randn(3, 3) - a = x ⊗ y - ah = Hermitian(x) ⊗ Hermitian(y) + elt = Float32 + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d, v = eig_full(a) @test a * v ≈ v * d + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) @test_throws MethodError eig_trunc(a) + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) d = eig_vals(a) @test d ≈ diag(eig_full(a)[1]) - d, v = eigh_full(ah) - @test ah * v ≈ v * d + a = Hermitian(randn(elt, 2, 2)) ⊗ Hermitian(randn(elt, 3, 3)) + d, v = eigh_full(a) + @test a * v ≈ v * d + + a = Hermitian(randn(elt, 2, 2)) ⊗ Hermitian(randn(elt, 3, 3)) + @test_throws MethodError eigh_trunc(a) + + a = Hermitian(randn(elt, 2, 2)) ⊗ Hermitian(randn(elt, 3, 3)) + d = eigh_vals(a) + @test d ≈ diag(eigh_full(a)[1]) - @test_throws MethodError eigh_trunc(ah) + a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) + n = left_null(a) + @test norm(n' * a) ≈ 0 atol = √eps(real(elt)) - d = eigh_vals(ah) - @test d ≈ diag(eigh_full(ah)[1]) + a = randn(elt, 2, 3) ⊗ randn(elt, 3, 4) + n = right_null(a) + @test norm(a * n') ≈ 0 atol = √eps(real(elt)) end From 24286ecbd63ec646ebe6e33a1ac6706af1fb3322 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 28 May 2025 19:45:39 -0400 Subject: [PATCH 3/3] More factorizations and tests --- src/KroneckerArrays.jl | 167 +++++++++++++++++----------------- test/test_matrixalgebrakit.jl | 61 ++++++++++++- 2 files changed, 145 insertions(+), 83 deletions(-) diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index aa64a21..0783396 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -381,7 +381,10 @@ using MatrixAlgebraKit: TruncationStrategy, default_eig_algorithm, default_eigh_algorithm, + default_lq_algorithm, + default_polar_algorithm, default_qr_algorithm, + default_svd_algorithm, eig_full!, eig_trunc!, eig_vals!, @@ -390,7 +393,19 @@ using MatrixAlgebraKit: eigh_vals!, initialize_output, left_null!, + left_orth!, + left_polar!, + lq_compact!, + lq_full!, + qr_compact!, + qr_full!, right_null!, + right_orth!, + right_polar!, + svd_compact!, + svd_full!, + svd_trunc!, + svd_vals!, truncate! struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm @@ -398,99 +413,87 @@ struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm b::B end -function MatrixAlgebraKit.default_eig_algorithm(a::KroneckerMatrix) - return KroneckerAlgorithm(default_eig_algorithm(a.a), default_eig_algorithm(a.b)) -end -function MatrixAlgebraKit.initialize_output( - f::typeof(eig_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm -) - return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b) -end -function MatrixAlgebraKit.eig_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) - eig_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a) - eig_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b) - return F -end - -function MatrixAlgebraKit.truncate!( - ::typeof(eig_trunc!), - (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, - strategy::TruncationStrategy, -) - return throw(MethodError(truncate!, (eig_trunc!, (D, V), strategy))) -end - -function MatrixAlgebraKit.initialize_output( - f::typeof(eig_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm -) - return initialize_output(f, a.a, alg.a) ⊗ initialize_output(f, a.b, alg.b) -end -function MatrixAlgebraKit.eig_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) - eig_vals!(a.a, F.a, alg.a) - eig_vals!(a.b, F.b, alg.b) - return F -end - -function MatrixAlgebraKit.default_eigh_algorithm(a::KroneckerMatrix) - return KroneckerAlgorithm(default_eigh_algorithm(a.a), default_eigh_algorithm(a.b)) -end -function MatrixAlgebraKit.initialize_output( - f::typeof(eigh_full!), a::KroneckerMatrix, alg::KroneckerAlgorithm -) - return initialize_output(f, a.a, alg.a) .⊗ initialize_output(f, a.b, alg.b) -end -function MatrixAlgebraKit.eigh_full!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) - eigh_full!(a.a, Base.Fix2(getfield, :a).(F), alg.a) - eigh_full!(a.b, Base.Fix2(getfield, :b).(F), alg.b) - return F +for f in (:eig, :eigh, :lq, :qr, :polar, :svd) + ff = Symbol("default_", f, "_algorithm") + @eval begin + function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...) + return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...)) + end + end end -function MatrixAlgebraKit.truncate!( - ::typeof(eigh_trunc!), - (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, - strategy::TruncationStrategy, +for f in ( + :eig_full!, + :eigh_full!, + :qr_compact!, + :qr_full!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_polar!, + :svd_compact!, + :svd_full!, ) - return throw(MethodError(truncate!, (eigh_trunc!, (D, V), strategy))) + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm + ) + return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...) + $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...) + $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...) + return F + end + end end -function MatrixAlgebraKit.initialize_output( - f::typeof(eigh_vals!), a::KroneckerMatrix, alg::KroneckerAlgorithm -) - return initialize_output(f, a.a, alg.a) ⊗ initialize_output(f, a.b, alg.b) -end -function MatrixAlgebraKit.eigh_vals!(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) - eigh_vals!(a.a, F.a, alg.a) - eigh_vals!(a.b, F.b, alg.b) - return F +for f in (:eig_vals!, :eigh_vals!, :svd_vals!) + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm + ) + return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + $f(a.a, F.a, alg.a) + $f(a.b, F.b, alg.b) + return F + end + end end -function MatrixAlgebraKit.default_qr_algorithm(a::KroneckerMatrix; kwargs...) - return KroneckerAlgorithm( - default_qr_algorithm(a.a; kwargs...), default_qr_algorithm(a.b; kwargs...) - ) -end -function MatrixAlgebraKit.default_lq_algorithm(a::KroneckerMatrix; kwargs...) - return KroneckerAlgorithm( - default_lq_algorithm(a.a; kwargs...), default_lq_algorithm(a.b; kwargs...) - ) +for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!) + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), + (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, + strategy::TruncationStrategy, + ) + return throw(MethodError(truncate!, ($f, (D, V), strategy))) + end + end end -function MatrixAlgebraKit.initialize_output(f::typeof(left_null!), a::KroneckerMatrix) - return initialize_output(f, a.a) ⊗ initialize_output(f, a.b) -end -function MatrixAlgebraKit.left_null!(a::KroneckerMatrix, F; kwargs...) - left_null!(a.a, F.a; kwargs...) - left_null!(a.b, F.b; kwargs...) - return F +for f in (:left_orth!, :right_orth!) + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) + end + end end -function MatrixAlgebraKit.initialize_output(f::typeof(right_null!), a::KroneckerMatrix) - return initialize_output(f, a.a) ⊗ initialize_output(f, a.b) -end -function MatrixAlgebraKit.right_null!(a::KroneckerMatrix, F; kwargs...) - right_null!(a.a, F.a; kwargs...) - right_null!(a.b, F.b; kwargs...) - return F +for f in (:left_null!, :right_null!) + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...) + $f(a.a, F.a; kwargs...) + $f(a.b, F.b; kwargs...) + return F + end + end end end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index f5b4baf..c3dc399 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -1,5 +1,5 @@ using KroneckerArrays: ⊗ -using LinearAlgebra: Hermitian, diag, norm +using LinearAlgebra: Hermitian, I, diag, norm using MatrixAlgebraKit: eig_full, eig_trunc, @@ -48,6 +48,26 @@ using Test: @test, @test_throws, @testset d = eigh_vals(a) @test d ≈ diag(eigh_full(a)[1]) + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = qr_compact(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = qr_full(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = lq_compact(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = lq_full(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) n = left_null(a) @test norm(n' * a) ≈ 0 atol = √eps(real(elt)) @@ -55,4 +75,43 @@ using Test: @test, @test_throws, @testset a = randn(elt, 2, 3) ⊗ randn(elt, 3, 4) n = right_null(a) @test norm(a * n') ≈ 0 atol = √eps(real(elt)) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = left_orth(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = right_orth(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = left_polar(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = right_polar(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, s, v = svd_compact(a) + @test u * s * v ≈ a + @test collect(u'u) ≈ I + @test collect(v * v') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, s, v = svd_full(a) + @test u * s * v ≈ a + @test collect(u'u) ≈ I + @test collect(v * v') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test_throws MethodError svd_trunc(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + s = svd_vals(a) + @test s ≈ diag(svd_compact(a)[2]) end