diff --git a/Project.toml b/Project.toml index 8da5421..38353b5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 3ae3fad..0bce6f8 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -11,6 +11,11 @@ end arguments(a::CartesianProduct) = (a.a, a.b) arguments(a::CartesianProduct, n::Int) = arguments(a)[n] +function Base.show(io::IO, a::CartesianProduct) + print(io, a.a, " × ", a.b) + return nothing +end + ×(a, b) = CartesianProduct(a, b) Base.length(a::CartesianProduct) = length(a.a) * length(a.b) Base.getindex(a::CartesianProduct, i::CartesianProduct) = a.a[i.a] × a.b[i.b] @@ -130,6 +135,8 @@ function interleave(x::Tuple, y::Tuple) xy = ntuple(i -> (x[i], y[i]), length(x)) return flatten(xy) end +# TODO: Maybe use scalar indexing based on KroneckerProducts.jl logic for cartesian indexing: +# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 function kron_nd(a::AbstractArray{<:Any,N}, b::AbstractArray{<:Any,N}) where {N} a′ = reshape(a, interleave(size(a), ntuple(one, N))) b′ = reshape(b, interleave(ntuple(one, N), size(b))) @@ -183,6 +190,9 @@ function Base.getindex(a::KroneckerArray, i::Integer) return a[CartesianIndices(a)[i]] end +# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing +# in the n-dimensional case and use it to replace the matrix and vector cases: +# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66 function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N} return error("Not implemented.") end @@ -222,6 +232,10 @@ end function Base.inv(a::KroneckerArray) return inv(a.a) ⊗ inv(a.b) end +using LinearAlgebra: LinearAlgebra, pinv +function LinearAlgebra.pinv(a::KroneckerArray; kwargs...) + return pinv(a.a; kwargs...) ⊗ pinv(a.b; kwargs...) +end function Base.transpose(a::KroneckerArray) return transpose(a.a) ⊗ transpose(a.b) end @@ -297,6 +311,7 @@ using LinearAlgebra: Diagonal, Eigen, SVD, + det, diag, eigen, eigvals, @@ -335,9 +350,63 @@ end function LinearAlgebra.norm(a::KroneckerArray, p::Int=2) return norm(a.a, p) ⊗ norm(a.b, p) end + +using MatrixAlgebraKit: MatrixAlgebraKit, diagview +function MatrixAlgebraKit.diagview(a::KroneckerMatrix) + return diagview(a.a) ⊗ diagview(a.b) +end function LinearAlgebra.diag(a::KroneckerArray) - return diag(a.a) ⊗ diag(a.b) + return copy(diagview(a.a)) ⊗ copy(diagview(a.b)) +end + +# Matrix functions +const MATRIX_FUNCTIONS = [ + :exp, + :cis, + :log, + :sqrt, + :cbrt, + :cos, + :sin, + :tan, + :csc, + :sec, + :cot, + :cosh, + :sinh, + :tanh, + :csch, + :sech, + :coth, + :acos, + :asin, + :atan, + :acsc, + :asec, + :acot, + :acosh, + :asinh, + :atanh, + :acsch, + :asech, + :acoth, +] + +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f(a::KroneckerArray) + return throw(ArgumentError("Generic KroneckerArray `$($f)` is not supported.")) + end + end +end + +using LinearAlgebra: checksquare +function LinearAlgebra.det(a::KroneckerArray) + checksquare(a.a) + checksquare(a.b) + return det(a.a) ^ size(a.b, 1) * det(a.b) ^ size(a.a, 1) end + function LinearAlgebra.svd(a::KroneckerArray) Fa = svd(a.a) Fb = svd(a.b) @@ -690,18 +759,6 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!] end end -for f in [:eig_trunc!, :eigh_trunc!, :svd_trunc!] - @eval begin - function MatrixAlgebraKit.truncate!( - ::typeof($f), - (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, - strategy::TruncationStrategy, - ) - return throw(MethodError(truncate!, ($f, (D, V), strategy))) - end - end -end - for f in [:left_orth!, :right_orth!] @eval begin function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) @@ -941,4 +998,110 @@ for f in [:eig_vals!, :eigh_vals!, :svd_vals!] end end +using MatrixAlgebraKit: TruncationStrategy, diagview, findtruncated, truncate! + +struct KroneckerTruncationStrategy{T<:TruncationStrategy} <: TruncationStrategy + strategy::T +end + +# Avoid instantiating the identity. +function Base.getindex(a::SquareEyeKronecker, I::Vararg{CartesianProduct{Colon},2}) + return a.a ⊗ a.b[I[1].b, I[2].b] +end +function Base.getindex(a::KroneckerSquareEye, I::Vararg{CartesianProduct{<:Any,Colon},2}) + return a.a[I[1].a, I[2].a] ⊗ a.b +end +function Base.getindex(a::SquareEyeSquareEye, I::Vararg{CartesianProduct{Colon,Colon},2}) + return a +end + +using FillArrays: OnesVector +const OnesKroneckerVector{T,A<:OnesVector{T},B<:AbstractVector{T}} = KroneckerVector{T,A,B} +const KroneckerOnesVector{T,A<:AbstractVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} +const OnesVectorOnesVector{T,A<:OnesVector{T},B<:OnesVector{T}} = KroneckerVector{T,A,B} + +function MatrixAlgebraKit.findtruncated( + values::OnesKroneckerVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.a, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.a == i, prods) == length(values.a) + end + return (:) × I_data +end +function MatrixAlgebraKit.findtruncated( + values::KroneckerOnesVector, strategy::KroneckerTruncationStrategy +) + I = findtruncated(Vector(values), strategy.strategy) + prods = collect(only(axes(values)).product)[I] + I_data = unique(map(x -> x.b, prods)) + # Drop truncations that occur within the identity. + I_data = filter(I_data) do i + return count(x -> x.b == i, prods) == length(values.b) + end + return I_data × (:) +end +function MatrixAlgebraKit.findtruncated( + values::OnesVectorOnesVector, strategy::KroneckerTruncationStrategy +) + return throw(ArgumentError("Can't truncate Eye ⊗ Eye.")) +end + +for f in [:eig_trunc!, :eigh_trunc!] + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), DV::NTuple{2,KroneckerMatrix}, strategy::TruncationStrategy + ) + return truncate!($f, DV, KroneckerTruncationStrategy(strategy)) + end + function MatrixAlgebraKit.truncate!( + ::typeof($f), (D, V)::NTuple{2,KroneckerMatrix}, strategy::KroneckerTruncationStrategy + ) + I = findtruncated(diagview(D), strategy) + return (D[I, I], V[(:) × (:), I]) + end + end +end + +function MatrixAlgebraKit.truncate!( + f::typeof(svd_trunc!), USVᴴ::NTuple{3,KroneckerMatrix}, strategy::TruncationStrategy +) + return truncate!(f, USVᴴ, KroneckerTruncationStrategy(strategy)) +end +function MatrixAlgebraKit.truncate!( + ::typeof(svd_trunc!), + (U, S, Vᴴ)::NTuple{3,KroneckerMatrix}, + strategy::KroneckerTruncationStrategy, +) + I = findtruncated(diagview(S), strategy) + return (U[(:) × (:), I], S[I, I], Vᴴ[I, (:) × (:)]) +end + +for f in MATRIX_FUNCTIONS + @eval begin + function Base.$f(a::SquareEyeKronecker) + return a.a ⊗ $f(a.b) + end + function Base.$f(a::KroneckerSquareEye) + return $f(a.a) ⊗ a.b + end + function Base.$f(a::SquareEyeSquareEye) + return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported.")) + end + end +end + +function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...) + return a.a ⊗ pinv(a.b; kwargs...) +end +function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...) + return pinv(a.a; kwargs...) ⊗ a.b +end +function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...) + return a +end + end diff --git a/test/Project.toml b/test/Project.toml index aebc8e9..7f96924 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" @@ -16,6 +17,7 @@ KroneckerArrays = "0.1" LinearAlgebra = "1.10" MatrixAlgebraKit = "0.2" SafeTestsets = "0.1" +StableRNGs = "1.0" Suppressor = "0.2" Test = "1.10" TestExtras = "0.3" diff --git a/test/test_basics.jl b/test/test_basics.jl index cb57239..c2aaedc 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,7 +1,8 @@ using FillArrays: Eye using KroneckerArrays: KroneckerArrays, ⊗, ×, diagonal, kron_nd -using LinearAlgebra: Diagonal, I, eigen, eigvals, lq, qr, svd, svdvals, tr -using Test: @test, @testset +using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr +using StableRNGs: StableRNG +using Test: @test, @test_broken, @test_throws, @testset const elts = (Float32, Float64, ComplexF32, ComplexF64) @testset "KroneckerArrays (eltype=$elt)" for elt in elts @@ -35,7 +36,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) @test iszero(a - a) @test collect(a + c) ≈ collect(a) + collect(c) @test collect(b + c) ≈ collect(b) + collect(c) - for f in (transpose, adjoint, inv) + for f in (transpose, adjoint, inv, pinv) @test collect(f(a)) ≈ f(collect(a)) end @test tr(a) ≈ tr(collect(a)) @@ -66,9 +67,25 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) Q, R = qr(a) @test collect(Q * R) ≈ collect(a) @test collect(Q'Q) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test det(a) ≈ det(collect(a)) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + for f in KroneckerArrays.MATRIX_FUNCTIONS + @eval begin + @test_throws ArgumentError $f($a) + end + end end @testset "FillArrays.Eye" begin + MATRIX_FUNCTIONS = KroneckerArrays.MATRIX_FUNCTIONS + if VERSION < v"1.11-" + # `cbrt(::AbstractMatrix{<:Real})` was implemented in Julia 1.11. + MATRIX_FUNCTIONS = setdiff(MATRIX_FUNCTIONS, [:cbrt]) + end + a = Eye(2) ⊗ randn(3, 3) @test size(a) == (6, 6) @test a + a == Eye(2) ⊗ (2a.b) @@ -80,4 +97,66 @@ end @test a + a == (2a.a) ⊗ Eye(2) @test 2a == (2a.a) ⊗ Eye(2) @test a * a == (a.a * a.a) ⊗ Eye(2) + + # Eye ⊗ A + rng = StableRNG(123) + a = Eye(2) ⊗ randn(rng, 3, 3) + for f in MATRIX_FUNCTIONS + @eval begin + fa = $f($a) + @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) + @test fa.a isa Eye + end + end + + fa = inv(a) + @test collect(fa) ≈ inv(collect(a)) + @test fa.a isa Eye + + fa = pinv(a) + @test collect(fa) ≈ pinv(collect(a)) + @test fa.a isa Eye + + @test det(a) ≈ det(collect(a)) + + # A ⊗ Eye + rng = StableRNG(123) + a = randn(rng, 3, 3) ⊗ Eye(2) + for f in setdiff(MATRIX_FUNCTIONS, [:atanh]) + @eval begin + fa = $f($a) + @test collect(fa) ≈ $f(collect($a)) rtol = ∜(eps(real(eltype($a)))) + @test fa.b isa Eye + end + end + + fa = inv(a) + @test collect(fa) ≈ inv(collect(a)) + @test fa.b isa Eye + + fa = pinv(a) + @test collect(fa) ≈ pinv(collect(a)) + @test fa.b isa Eye + + @test det(a) ≈ det(collect(a)) + + # Eye ⊗ Eye + a = Eye(2) ⊗ Eye(2) + for f in KroneckerArrays.MATRIX_FUNCTIONS + @eval begin + @test_throws ArgumentError $f($a) + end + end + + fa = inv(a) + @test fa == a + @test fa.a isa Eye + @test fa.b isa Eye + + fa = pinv(a) + @test fa == a + @test fa.a isa Eye + @test fa.b isa Eye + + @test det(a) ≈ det(collect(a)) ≈ 1 end diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 82943cd..41983e5 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -121,12 +121,6 @@ herm(a) = parent(hermitianpart(a)) end @testset "MatrixAlgebraKit + Eye" begin - - # TODO: - # eig_trunc - # eigh_trunc - # svd_trunc - for f in (eig_full, eigh_full) a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) d, v = @constinferred f(a) @@ -149,6 +143,27 @@ end @test arguments(v, 2) isa Eye end + for f in (eig_trunc, eigh_trunc) + a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye + @test arguments(v, 1) isa Eye + @test size(d) == (6, 6) + @test size(v) == (9, 6) + + a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) + d, v = f(a; trunc=(; maxrank=7)) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye + @test arguments(v, 2) isa Eye + @test size(d) == (6, 6) + @test size(v) == (9, 6) + + a = Eye(3) ⊗ Eye(3) + @test_throws ArgumentError f(a) + end + for f in (eig_vals, eigh_vals) a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) d = @constinferred f(a) @@ -221,6 +236,33 @@ end @test arguments(v, 2) isa Eye end + # svd_trunc + a = Eye(3) ⊗ randn(3, 3) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 1) isa Eye + @test arguments(s, 1) isa Eye + @test arguments(v, 1) isa Eye + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + + a = randn(3, 3) ⊗ Eye(3) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 2) isa Eye + @test arguments(s, 2) isa Eye + @test arguments(v, 2) isa Eye + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + + a = Eye(3) ⊗ Eye(3) + @test_throws ArgumentError svd_trunc(a) + + # svd_vals a = Eye(3) ⊗ randn(3, 3) d = @constinferred svd_vals(a) d′ = svd_vals(Matrix(a))