diff --git a/Project.toml b/Project.toml index c612160..8e8934f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,17 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.0" +version = "0.1.1" [deps] +DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [compat] +DerivableInterfaces = "0.4.5" +GPUArraysCore = "0.2.0" LinearAlgebra = "1.10" +MatrixAlgebraKit = "0.2.0" julia = "1.10" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index e91790a..0783396 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -1,5 +1,7 @@ module KroneckerArrays +using GPUArraysCore: GPUArraysCore + export ⊗, × struct CartesianProduct{A,B} @@ -28,6 +30,26 @@ end Base.first(r::CartesianProductUnitRange) = first(r.range) Base.last(r::CartesianProductUnitRange) = last(r.range) +function Base.axes(r::CartesianProductUnitRange) + return (CartesianProductUnitRange(r.product, only(axes(r.range))),) +end + +using Base.Broadcast: DefaultArrayStyle +for f in (:+, :-) + @eval begin + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), r::CartesianProductUnitRange, x::Integer + ) + return CartesianProductUnitRange(r.product, $f.(r.range, x)) + end + function Broadcast.broadcasted( + ::DefaultArrayStyle{1}, ::typeof($f), x::Integer, r::CartesianProductUnitRange + ) + return CartesianProductUnitRange(r.product, $f.(x, r.range)) + end + end +end + struct KroneckerArray{T,N,A<:AbstractArray{T,N},B<:AbstractArray{T,N}} <: AbstractArray{T,N} a::A b::B @@ -44,6 +66,15 @@ end const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B} const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B} +function Base.copy(a::KroneckerArray) + return copy(a.a) ⊗ copy(a.b) +end +function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) + copyto!(dest.a, src.a) + copyto!(dest.b, src.b) + return dest +end + function Base.similar( a::AbstractArray, elt::Type, @@ -73,9 +104,21 @@ function Base.similar( return similar(arrayt, map(ax -> ax.product.a, axs)) ⊗ similar(arrayt, map(ax -> ax.product.b, axs)) end +function Base.similar( + arrayt::Type{<:KroneckerArray{<:Any,<:Any,A,B}}, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) where {A,B} + return similar(A, map(ax -> ax.product.a, axs)) ⊗ similar(B, map(ax -> ax.product.b, axs)) +end Base.collect(a::KroneckerArray) = kron(a.a, a.b) +function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N} + return convert(Array{T,N}, collect(a)) +end + Base.size(a::KroneckerArray) = ntuple(dim -> size(a.a, dim) * size(a.b, dim), ndims(a)) function Base.axes(a::KroneckerArray) @@ -107,12 +150,23 @@ end ⊗(a::Number, b::AbstractVecOrMat) = a * b ⊗(a::AbstractVecOrMat, b::Number) = a * b -function Base.getindex(::KroneckerArray, ::Int) - return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported.")) +function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer) + GPUArraysCore.assertscalar("getindex") + # Code logic from Kronecker.jl: + # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105 + k, l = size(a.b) + return a.a[cld(i1, k), cld(i2, l)] * a.b[(i1 - 1) % k + 1, (i2 - 1) % l + 1] +end +function Base.getindex(a::KroneckerMatrix, i::Integer) + return a[CartesianIndices(a)[i]] end -function Base.getindex(::KroneckerArray{<:Any,N}, ::Vararg{Int,N}) where {N} - return throw(ArgumentError("Scalar indexing of KroneckerArray is not supported.")) + +function Base.getindex(a::KroneckerVector, i::Integer) + GPUArraysCore.assertscalar("getindex") + k = length(a.b) + return a.a[cld(i, k)] * a.b[(i - 1) % k + 1] end + function Base.getindex(a::KroneckerVector, i::CartesianProduct) return a.a[i.a] ⊗ a.b[i.b] end @@ -169,9 +223,18 @@ end function Base.:*(a::KroneckerArray, b::KroneckerArray) return (a.a * b.a) ⊗ (a.b * b.b) end -function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray) +function LinearAlgebra.mul!( + c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number +) + iszero(β) || + iszero(c) || + throw( + ArgumentError( + "Can't multiple KroneckerArrays with nonzero β and nonzero destination." + ), + ) mul!(c.a, a.a, b.a) - mul!(c.b, a.b, b.b) + mul!(c.b, a.b, b.b, α, β) return c end function LinearAlgebra.tr(a::KroneckerArray) @@ -269,4 +332,168 @@ for op in (:+, :-) end end +function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) + dest.a .= a.a + dest.b .= a.b + return dest +end +function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray) + if a.b == b.b + map!(+, dest.a, a.a, b.a) + dest.b .= a.b + elseif a.a == b.a + dest.a .= a.a + map!(+, dest.b, a.b, b.b) + else + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + return dest +end +function Base.map!( + f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray +) + dest.a .= f.x .* a.a + dest.b .= a.b + return dest +end +function Base.map!( + f::Base.Fix2{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray +) + dest.a .= a.a + dest.b .= a.b .* f.x + return dest +end + +using DerivableInterfaces: DerivableInterfaces, zero! +function DerivableInterfaces.zero!(a::KroneckerArray) + zero!(a.a) + zero!(a.b) + return a +end + +using MatrixAlgebraKit: + MatrixAlgebraKit, + AbstractAlgorithm, + TruncationStrategy, + default_eig_algorithm, + default_eigh_algorithm, + default_lq_algorithm, + default_polar_algorithm, + default_qr_algorithm, + default_svd_algorithm, + eig_full!, + eig_trunc!, + eig_vals!, + eigh_full!, + eigh_trunc!, + eigh_vals!, + initialize_output, + left_null!, + left_orth!, + left_polar!, + lq_compact!, + lq_full!, + qr_compact!, + qr_full!, + right_null!, + right_orth!, + right_polar!, + svd_compact!, + svd_full!, + svd_trunc!, + svd_vals!, + truncate! + +struct KroneckerAlgorithm{A,B} <: AbstractAlgorithm + a::A + b::B +end + +for f in (:eig, :eigh, :lq, :qr, :polar, :svd) + ff = Symbol("default_", f, "_algorithm") + @eval begin + function MatrixAlgebraKit.$ff(a::KroneckerMatrix; kwargs...) + return KroneckerAlgorithm($ff(a.a; kwargs...), $ff(a.b; kwargs...)) + end + end +end + +for f in ( + :eig_full!, + :eigh_full!, + :qr_compact!, + :qr_full!, + :left_polar!, + :lq_compact!, + :lq_full!, + :right_polar!, + :svd_compact!, + :svd_full!, +) + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm + ) + return initialize_output($f, a.a, alg.a) .⊗ initialize_output($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm; kwargs...) + $f(a.a, Base.Fix2(getfield, :a).(F), alg.a; kwargs...) + $f(a.b, Base.Fix2(getfield, :b).(F), alg.b; kwargs...) + return F + end + end +end + +for f in (:eig_vals!, :eigh_vals!, :svd_vals!) + @eval begin + function MatrixAlgebraKit.initialize_output( + ::typeof($f), a::KroneckerMatrix, alg::KroneckerAlgorithm + ) + return initialize_output($f, a.a, alg.a) ⊗ initialize_output($f, a.b, alg.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F, alg::KroneckerAlgorithm) + $f(a.a, F.a, alg.a) + $f(a.b, F.b, alg.b) + return F + end + end +end + +for f in (:eig_trunc!, :eigh_trunc!, :svd_trunc!) + @eval begin + function MatrixAlgebraKit.truncate!( + ::typeof($f), + (D, V)::Tuple{KroneckerMatrix,KroneckerMatrix}, + strategy::TruncationStrategy, + ) + return throw(MethodError(truncate!, ($f, (D, V), strategy))) + end + end +end + +for f in (:left_orth!, :right_orth!) + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return initialize_output($f, a.a) .⊗ initialize_output($f, a.b) + end + end +end + +for f in (:left_null!, :right_null!) + @eval begin + function MatrixAlgebraKit.initialize_output(::typeof($f), a::KroneckerMatrix) + return initialize_output($f, a.a) ⊗ initialize_output($f, a.b) + end + function MatrixAlgebraKit.$f(a::KroneckerMatrix, F; kwargs...) + $f(a.a, F.a; kwargs...) + $f(a.b, F.b; kwargs...) + return F + end + end +end + end diff --git a/test/Project.toml b/test/Project.toml index 3b27675..772eff2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl new file mode 100644 index 0000000..c3dc399 --- /dev/null +++ b/test/test_matrixalgebrakit.jl @@ -0,0 +1,117 @@ +using KroneckerArrays: ⊗ +using LinearAlgebra: Hermitian, I, diag, norm +using MatrixAlgebraKit: + eig_full, + eig_trunc, + eig_vals, + eigh_full, + eigh_trunc, + eigh_vals, + left_null, + left_orth, + left_polar, + lq_compact, + lq_full, + qr_compact, + qr_full, + right_null, + right_orth, + right_polar, + svd_compact, + svd_full, + svd_trunc, + svd_vals +using Test: @test, @test_throws, @testset + +@testset "MatrixAlgebraKit" begin + elt = Float32 + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + d, v = eig_full(a) + @test a * v ≈ v * d + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test_throws MethodError eig_trunc(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + d = eig_vals(a) + @test d ≈ diag(eig_full(a)[1]) + + a = Hermitian(randn(elt, 2, 2)) ⊗ Hermitian(randn(elt, 3, 3)) + d, v = eigh_full(a) + @test a * v ≈ v * d + + a = Hermitian(randn(elt, 2, 2)) ⊗ Hermitian(randn(elt, 3, 3)) + @test_throws MethodError eigh_trunc(a) + + a = Hermitian(randn(elt, 2, 2)) ⊗ Hermitian(randn(elt, 3, 3)) + d = eigh_vals(a) + @test d ≈ diag(eigh_full(a)[1]) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = qr_compact(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = qr_full(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = lq_compact(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = lq_full(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 3, 2) ⊗ randn(elt, 4, 3) + n = left_null(a) + @test norm(n' * a) ≈ 0 atol = √eps(real(elt)) + + a = randn(elt, 2, 3) ⊗ randn(elt, 3, 4) + n = right_null(a) + @test norm(a * n') ≈ 0 atol = √eps(real(elt)) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = left_orth(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = right_orth(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, c = left_polar(a) + @test u * c ≈ a + @test collect(u'u) ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + c, u = right_polar(a) + @test c * u ≈ a + @test collect(u * u') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, s, v = svd_compact(a) + @test u * s * v ≈ a + @test collect(u'u) ≈ I + @test collect(v * v') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + u, s, v = svd_full(a) + @test u * s * v ≈ a + @test collect(u'u) ≈ I + @test collect(v * v') ≈ I + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + @test_throws MethodError svd_trunc(a) + + a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) + s = svd_vals(a) + @test s ≈ diag(svd_compact(a)[2]) +end