diff --git a/Project.toml b/Project.toml index c091764..c961002 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" +DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -12,6 +13,7 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4" [compat] DerivableInterfaces = "0.4.5" +DiagonalArrays = "0.3.5" FillArrays = "1.13.0" GPUArraysCore = "0.2.0" LinearAlgebra = "1.10" diff --git a/src/KroneckerArrays.jl b/src/KroneckerArrays.jl index 92ca131..001aee0 100644 --- a/src/KroneckerArrays.jl +++ b/src/KroneckerArrays.jl @@ -250,6 +250,9 @@ end function Base.iszero(a::KroneckerArray) return iszero(a.a) || iszero(a.b) end +function Base.isreal(a::KroneckerArray) + return isreal(a.a) && isreal(a.b) +end function Base.inv(a::KroneckerArray) return inv(a.a) ⊗ inv(a.b) end @@ -270,6 +273,9 @@ end function Base.:*(a::KroneckerArray, b::Number) return a.a ⊗ (a.b * b) end +function Base.:/(a::KroneckerArray, b::Number) + return a * inv(b) +end function Base.:-(a::KroneckerArray) return (-a.a) ⊗ a.b @@ -291,26 +297,82 @@ for op in (:+, :-) end end +using Base.Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted +struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end +function KroneckerStyle{N}(a::BroadcastStyle, b::BroadcastStyle) where {N} + return KroneckerStyle{N,a,b}() +end +function KroneckerStyle(a::AbstractArrayStyle{N}, b::AbstractArrayStyle{N}) where {N} + return KroneckerStyle{N}(a, b) +end +function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M} + return KroneckerStyle{M,typeof(A)(v),typeof(B)(v)}() +end +function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any,N,A,B}}) where {N,A,B} + return KroneckerStyle{N}(BroadcastStyle(A), BroadcastStyle(B)) +end +function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N} + return KroneckerStyle{N}( + BroadcastStyle(style1.a, style2.a), BroadcastStyle(style1.b, style2.b) + ) +end +function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type) where {N,A,B} + ax_a = map(ax -> ax.product.a, axes(bc)) + ax_b = map(ax -> ax.product.b, axes(bc)) + bc_a = Broadcasted(A, nothing, (), ax_a) + bc_b = Broadcasted(B, nothing, (), ax_b) + a = similar(bc_a, elt) + b = similar(bc_b, elt) + return a ⊗ b +end +function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:KroneckerStyle}) + return throw( + ArgumentError( + "Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", + ), + ) +end + +function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...) + return throw( + ArgumentError( + "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", + ), + ) +end +function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...) + return throw( + ArgumentError( + "Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure.", + ), + ) +end function Base.map!(::typeof(identity), dest::KroneckerArray, a::KroneckerArray) dest.a .= a.a dest.b .= a.b return dest end -function Base.map!(::typeof(+), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray) - if a.b == b.b - map!(+, dest.a, a.a, b.a) - dest.b .= a.b - elseif a.a == b.a - dest.a .= a.a - map!(+, dest.b, a.b, b.b) - else - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), +for f in [:+, :-] + @eval begin + function Base.map!( + ::typeof($f), dest::KroneckerArray, a::KroneckerArray, b::KroneckerArray ) + if a.b == b.b + map!($f, dest.a, a.a, b.a) + dest.b .= a.b + elseif a.a == b.a + dest.a .= a.a + map!($f, dest.b, a.b, b.b) + else + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + return dest + end end - return dest end function Base.map!( f::Base.Fix1{typeof(*),<:Number}, dest::KroneckerArray, a::KroneckerArray @@ -326,6 +388,16 @@ function Base.map!( dest.b .= f.f.(a.b, f.x) return dest end +function Base.map!( + f::Base.Fix2{typeof(/),<:Number}, dest::KroneckerArray, a::KroneckerArray +) + return map!(Base.Fix2(*, inv(f.x)), dest, a) +end +function Base.map!(::typeof(conj), dest::KroneckerArray, a::KroneckerArray) + dest.a .= conj.(a.a) + dest.b .= conj.(a.b) + return dest +end using LinearAlgebra: LinearAlgebra, @@ -343,9 +415,10 @@ using LinearAlgebra: svd, svdvals, tr -diagonal(a::AbstractArray) = Diagonal(a) -function diagonal(a::KroneckerArray) - return Diagonal(a.a) ⊗ Diagonal(a.b) + +using DiagonalArrays: DiagonalArrays, diagonal +function DiagonalArrays.diagonal(a::KroneckerArray) + return diagonal(a.a) ⊗ diagonal(a.b) end function Base.:*(a::KroneckerArray, b::KroneckerArray) @@ -372,6 +445,23 @@ function LinearAlgebra.norm(a::KroneckerArray, p::Int=2) return norm(a.a, p) ⊗ norm(a.b, p) end +function Base.real(a::KroneckerArray) + if iszero(imag(a.a)) || iszero(imag(a.b)) + return real(a.a) ⊗ real(a.b) + elseif iszero(real(a.a)) || iszero(real(a.b)) + return -imag(a.a) ⊗ imag(a.b) + end + return real(a.a) ⊗ real(a.b) - imag(a.a) ⊗ imag(a.b) +end +function Base.imag(a::KroneckerArray) + if iszero(imag(a.a)) || iszero(real(a.b)) + return real(a.a) ⊗ imag(a.b) + elseif iszero(real(a.a)) || iszero(imag(a.b)) + return imag(a.a) ⊗ real(a.b) + end + return real(a.a) ⊗ imag(a.b) + imag(a.a) ⊗ real(a.b) +end + using MatrixAlgebraKit: MatrixAlgebraKit, diagview function MatrixAlgebraKit.diagview(a::KroneckerMatrix) return diagview(a.a) ⊗ diagview(a.b) @@ -506,6 +596,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B} const KroneckerEye{T,A<:AbstractMatrix{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} const EyeEye{T,A<:Eye{T},B<:Eye{T}} = KroneckerMatrix{T,A,B} +using DerivableInterfaces: DerivableInterfaces, zero! +function DerivableInterfaces.zero!(a::EyeKronecker) + zero!(a.b) + return a +end +function DerivableInterfaces.zero!(a::KroneckerEye) + zero!(a.a) + return a +end +function DerivableInterfaces.zero!(a::EyeEye) + return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`.")) +end + function Base.:*(a::Number, b::EyeKronecker) return b.a ⊗ (a * b.b) end @@ -580,29 +683,44 @@ end function Base.map!(::typeof(identity), dest::EyeEye, a::EyeEye) return error("Can't write in-place.") end -function Base.map!(f::typeof(+), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) - if dest.a ≠ a.a ≠ b.a - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) +for f in [:+, :-] + @eval begin + function Base.map!(::typeof($f), dest::EyeKronecker, a::EyeKronecker, b::EyeKronecker) + if dest.a ≠ a.a ≠ b.a + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + map!($f, dest.b, a.b, b.b) + return dest + end + function Base.map!(::typeof($f), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) + if dest.b ≠ a.b ≠ b.b + throw( + ArgumentError( + "KroneckerArray addition is only supported when the first or second arguments match.", + ), + ) + end + map!($f, dest.a, a.a, b.a) + return dest + end + function Base.map!(::typeof($f), dest::EyeEye, a::EyeEye, b::EyeEye) + return error("Can't write in-place.") + end end - map!(f, dest.b, a.b, b.b) +end +function Base.map!(f::typeof(-), dest::EyeKronecker, a::EyeKronecker) + map!(f, dest.b, a.b) return dest end -function Base.map!(f::typeof(+), dest::KroneckerEye, a::KroneckerEye, b::KroneckerEye) - if dest.b ≠ a.b ≠ b.b - throw( - ArgumentError( - "KroneckerArray addition is only supported when the first or second arguments match.", - ), - ) - end - map!(f, dest.a, a.a, b.a) +function Base.map!(f::typeof(-), dest::KroneckerEye, a::KroneckerEye) + map!(f, dest.a, a.a) return dest end -function Base.map!(f::typeof(+), dest::EyeEye, a::EyeEye, b::EyeEye) +function Base.map!(f::typeof(-), dest::EyeEye, a::EyeEye) return error("Can't write in-place.") end function Base.map!(f::Base.Fix1{typeof(*),<:Number}, dest::EyeKronecker, a::EyeKronecker) @@ -812,6 +930,74 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr const KroneckerSquareEye{T,A<:AbstractMatrix{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} const SquareEyeSquareEye{T,A<:SquareEye{T},B<:SquareEye{T}} = KroneckerMatrix{T,A,B} +# Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`. +function Base.similar( + a::SquareEyeKronecker, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) + ax_a = map(ax -> ax.product.a, axs) + ax_b = map(ax -> ax.product.b, axs) + eye_ax_a = (only(unique(ax_a)),) + return Eye{elt}(eye_ax_a) ⊗ similar(a.b, elt, ax_b) +end +function Base.similar( + a::KroneckerSquareEye, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) + ax_a = map(ax -> ax.product.a, axs) + ax_b = map(ax -> ax.product.b, axs) + eye_ax_b = (only(unique(ax_b)),) + return similar(a.a, elt, ax_a) ⊗ Eye{elt}(eye_ax_b) +end +function Base.similar( + a::SquareEyeSquareEye, + elt::Type, + axs::Tuple{ + CartesianProductUnitRange{<:Integer},Vararg{CartesianProductUnitRange{<:Integer}} + }, +) + ax_a = map(ax -> ax.product.a, axs) + ax_b = map(ax -> ax.product.b, axs) + eye_ax_a = (only(unique(ax_a)),) + eye_ax_b = (only(unique(ax_b)),) + return Eye{elt}(eye_ax_a) ⊗ Eye{elt}(eye_ax_b) +end + +function Base.similar( + arrayt::Type{<:SquareEyeKronecker{<:Any,<:Any,A}}, + axs::NTuple{2,CartesianProductUnitRange{<:Integer}}, +) where {A} + ax_a = map(ax -> ax.product.a, axs) + ax_b = map(ax -> ax.product.b, axs) + eye_ax_a = (only(unique(ax_a)),) + return Eye{eltype(arrayt)}(eye_ax_a) ⊗ similar(A, ax_b) +end +function Base.similar( + arrayt::Type{<:KroneckerSquareEye{<:Any,A}}, + axs::NTuple{2,CartesianProductUnitRange{<:Integer}}, +) where {A} + ax_a = map(ax -> ax.product.a, axs) + ax_b = map(ax -> ax.product.b, axs) + eye_ax_b = (only(unique(ax_b)),) + return similar(A, ax_a) ⊗ Eye{eltype(arrayt)}(eye_ax_b) +end +function Base.similar( + arrayt::Type{<:SquareEyeSquareEye}, axs::NTuple{2,CartesianProductUnitRange{<:Integer}} +) + elt = eltype(arrayt) + ax_a = map(ax -> ax.product.a, axs) + ax_b = map(ax -> ax.product.b, axs) + eye_ax_a = (only(unique(ax_a)),) + eye_ax_b = (only(unique(ax_b)),) + return Eye{elt}(eye_ax_a) ⊗ Eye{elt}(eye_ax_b) +end + struct SquareEyeAlgorithm{KWargs<:NamedTuple} <: AbstractAlgorithm kwargs::KWargs end @@ -884,8 +1070,6 @@ for f in [:left_null!, :right_null!] end end for f in [ - :eig_full!, - :eigh_full!, :qr_compact!, :qr_full!, :left_orth!, @@ -900,10 +1084,14 @@ for f in [ _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a) end end +_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye) = complex.((a, a)) +_initialize_output_squareeye(::typeof(eig_full!), a::SquareEye, alg) = complex.((a, a)) +_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye) = (real(a), a) +_initialize_output_squareeye(::typeof(eigh_full!), a::SquareEye, alg) = (real(a), a) for f in [:svd_compact!, :svd_full!] @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, a, a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, a, a) + _initialize_output_squareeye(::typeof($f), a::SquareEye) = (a, real(a), a) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = (a, real(a), a) end end @@ -987,10 +1175,12 @@ function MatrixAlgebraKit.right_null!( return throw(MethodError(right_null!, (a, F))) end -for f in [:eig_vals!, :eigh_vals!, :svd_vals!] +_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye) = parent(a) +_initialize_output_squareeye(::typeof(eig_vals!), a::SquareEye, alg) = parent(a) +for f in [:eigh_vals!, svd_vals!] @eval begin - _initialize_output_squareeye(::typeof($f), a::SquareEye) = parent(a) - _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = parent(a) + _initialize_output_squareeye(::typeof($f), a::SquareEye) = real(parent(a)) + _initialize_output_squareeye(::typeof($f), a::SquareEye, alg) = real(parent(a)) end end diff --git a/test/Project.toml b/test/Project.toml index 7f96924..9423c59 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/test_basics.jl b/test/test_basics.jl index 76ae661..5e05c68 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,6 +1,10 @@ +using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted +using DerivableInterfaces: zero! using FillArrays: Eye using KroneckerArrays: KroneckerArrays, + KroneckerArray, + KroneckerStyle, CartesianProductUnitRange, ⊗, ×, @@ -9,7 +13,7 @@ using KroneckerArrays: diagonal, kron_nd, unproduct -using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, pinv, qr, svd, svdvals, tr +using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr using StableRNGs: StableRNG using Test: @test, @test_broken, @test_throws, @testset @@ -41,8 +45,10 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) b = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) c = a.a ⊗ b.b + @test a isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)} @test similar(typeof(a), (2, 3)) isa Matrix{elt} @test size(similar(typeof(a), (2, 3))) == (2, 3) + @test isreal(a) == (elt <: Real) @test a[1 × 1, 1 × 1] == a.a[1, 1] * a.b[1, 1] @test a[1 × 3, 2 × 1] == a.a[1, 2] * a.b[3, 1] @test a[1 × (2:3), 2 × 1] == a.a[1, 2] * a.b[2:3, 1] @@ -60,6 +66,7 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) @test collect(-a) == -collect(a) @test collect(3 * a) ≈ 3 * collect(a) @test collect(a * 3) ≈ collect(a) * 3 + @test collect(a / 3) ≈ collect(a) / 3 @test a + a == 2a @test iszero(a - a) @test collect(a + c) ≈ collect(a) + collect(c) @@ -68,6 +75,61 @@ const elts = (Float32, Float64, ComplexF32, ComplexF64) @test collect(f(a)) ≈ f(collect(a)) end @test tr(a) ≈ tr(collect(a)) + @test norm(a) ≈ norm(collect(a)) + + # Broadcasting + style = KroneckerStyle(BroadcastStyle(typeof(a.a)), BroadcastStyle(typeof(a.b))) + @test BroadcastStyle(typeof(a)) === style + @test_throws "not supported" sin.(a) + a′ = similar(a) + @test_throws "not supported" a′ .= sin.(a) + a′ = similar(a) + @test_broken a′ .= 2 .* a + bc = broadcasted(+, a, a) + @test bc.style === style + @test similar(bc, elt) isa KroneckerArray{elt,2,typeof(a.a),typeof(a.b)} + @test_broken copy(bc) + bc = broadcasted(*, 2, a) + @test bc.style === style + @test_broken copy(bc) + + # Mapping + @test_throws "not supported" map(sin, a) + @test_broken map(Base.Fix1(*, 2), a) + a′ = similar(a) + @test_throws "not supported" map!(sin, a′, a) + a′ = similar(a) + map!(identity, a′, a) + @test collect(a′) ≈ collect(a) + a′ = similar(a) + map!(+, a′, a, a) + @test collect(a′) ≈ 2 * collect(a) + a′ = similar(a) + map!(-, a′, a, a) + @test norm(collect(a′)) ≈ 0 + a′ = similar(a) + map!(Base.Fix1(*, 2), a′, a) + @test collect(a′) ≈ 2 * collect(a) + a′ = similar(a) + map!(Base.Fix2(*, 2), a′, a) + @test collect(a′) ≈ 2 * collect(a) + a′ = similar(a) + map!(Base.Fix2(/, 2), a′, a) + @test collect(a′) ≈ collect(a) / 2 + a′ = similar(a) + map!(conj, a′, a) + @test collect(a′) ≈ conj(collect(a)) + + if elt <: Real + @test real(a) == a + else + @test_throws ArgumentError real(a) + end + if elt <: Real + @test iszero(imag(a)) + else + @test_throws ArgumentError imag(a) + end a = randn(elt, 2, 2, 2) ⊗ randn(elt, 3, 3, 3) @test collect(a) ≈ kron_nd(a.a, a.b) @@ -126,6 +188,110 @@ end @test 2a == (2a.a) ⊗ Eye(2) @test a * a == (a.a * a.a) ⊗ Eye(2) + # similar + a = Eye(2) ⊗ randn(3, 3) + for a′ in ( + similar(a), + similar(a, eltype(a)), + similar(a, axes(a)), + similar(a, eltype(a), axes(a)), + similar(typeof(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} + @test a′.a === a.a + end + + a = Eye(2) ⊗ randn(3, 3) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32,ndims(a)} + @test a′.a === Eye{Float32}(2) + end + + a = randn(3, 3) ⊗ Eye(2) + for a′ in ( + similar(a), + similar(a, eltype(a)), + similar(a, axes(a)), + similar(a, eltype(a), axes(a)), + similar(typeof(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} + @test a′.b === a.b + end + + a = randn(3, 3) ⊗ Eye(2) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32,ndims(a)} + @test a′.b === Eye{Float32}(2) + end + + a = Eye(3) ⊗ Eye(2) + for a′ in ( + similar(a), + similar(a, eltype(a)), + similar(a, axes(a)), + similar(a, eltype(a), axes(a)), + similar(typeof(a), axes(a)), + ) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{eltype(a),ndims(a),typeof(a.a),typeof(a.b)} + @test a′.a === a.a + @test a′.b === a.b + end + + a = Eye(3) ⊗ Eye(2) + for args in ((Float32,), (Float32, axes(a))) + a′ = similar(a, args...) + @test size(a′) == (6, 6) + @test a′ isa KroneckerArray{Float32,ndims(a)} + @test a′.a === Eye{Float32}(3) + @test a′.b === Eye{Float32}(2) + end + + # DerivableInterfaces.zero! + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + zero!(a) + @test iszero(a) + end + a = Eye(3) ⊗ Eye(2) + @test_throws ArgumentError zero!(a) + + # map!(+, ...) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(+, a′, a, a) + @test collect(a′) ≈ 2 * collect(a) + end + a = Eye(3) ⊗ Eye(2) + a′ = similar(a) + @test_throws ErrorException map!(+, a′, a, a) + + # map!(-, ...) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(-, a′, a, a) + @test norm(collect(a′)) ≈ 0 + end + a = Eye(3) ⊗ Eye(2) + a′ = similar(a) + @test_throws ErrorException map!(-, a′, a, a) + + # map!(-, b, a) + for a in (Eye(2) ⊗ randn(3, 3), randn(3, 3) ⊗ Eye(2)) + a′ = similar(a) + map!(-, a′, a) + @test collect(a′) ≈ -collect(a) + end + a = Eye(3) ⊗ Eye(2) + a′ = similar(a) + @test_throws ErrorException map!(-, a′, a) + # Eye ⊗ A rng = StableRNG(123) a = Eye(2) ⊗ randn(rng, 3, 3) diff --git a/test/test_matrixalgebrakit.jl b/test/test_matrixalgebrakit.jl index 41983e5..e1b96e5 100644 --- a/test/test_matrixalgebrakit.jl +++ b/test/test_matrixalgebrakit.jl @@ -44,6 +44,8 @@ herm(a) = parent(hermitianpart(a)) a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) d, v = eigh_full(a) @test a * v ≈ v * d + @test eltype(d) === real(elt) + @test eltype(v) === elt a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) @test_throws MethodError eigh_trunc(a) @@ -51,6 +53,7 @@ herm(a) = parent(hermitianpart(a)) a = herm(randn(elt, 2, 2)) ⊗ herm(randn(elt, 3, 3)) d = eigh_vals(a) @test d ≈ diag(eigh_full(a)[1]) + @test eltype(d) === real(elt) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, c = qr_compact(a) @@ -103,12 +106,18 @@ herm(a) = parent(hermitianpart(a)) a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_compact(a) @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt @test collect(u'u) ≈ I @test collect(v * v') ≈ I a = randn(elt, 2, 2) ⊗ randn(elt, 3, 3) u, s, v = svd_full(a) @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt @test collect(u'u) ≈ I @test collect(v * v') ≈ I @@ -121,26 +130,48 @@ herm(a) = parent(hermitianpart(a)) end @testset "MatrixAlgebraKit + Eye" begin - for f in (eig_full, eigh_full) - a = Eye(3) ⊗ parent(hermitianpart(randn(3, 3))) - d, v = @constinferred f(a) + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + d, v = @constinferred eig_full(a) @test a * v ≈ v * d - @test arguments(d, 1) isa Eye - @test arguments(v, 1) isa Eye + @test arguments(d, 1) isa Eye{complex(elt)} + @test arguments(v, 1) isa Eye{complex(elt)} - a = parent(hermitianpart(randn(3, 3))) ⊗ Eye(3) - d, v = @constinferred f(a) + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) + d, v = @constinferred eig_full(a) @test a * v ≈ v * d - @test arguments(d, 2) isa Eye - @test arguments(v, 2) isa Eye + @test arguments(d, 2) isa Eye{complex(elt)} + @test arguments(v, 2) isa Eye{complex(elt)} - a = Eye(3) ⊗ Eye(3) - d, v = @constinferred f(a) + a = Eye{elt}(3) ⊗ Eye{elt}(3) + d, v = @constinferred eig_full(a) @test a * v ≈ v * d - @test arguments(d, 1) isa Eye - @test arguments(d, 2) isa Eye - @test arguments(v, 1) isa Eye - @test arguments(v, 2) isa Eye + @test arguments(d, 1) isa Eye{complex(elt)} + @test arguments(d, 2) isa Eye{complex(elt)} + @test arguments(v, 1) isa Eye{complex(elt)} + @test arguments(v, 2) isa Eye{complex(elt)} + end + + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ parent(hermitianpart(randn(elt, 3, 3))) + d, v = @constinferred eigh_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + + a = parent(hermitianpart(randn(elt, 3, 3))) ⊗ Eye{elt}(3) + d, v = @constinferred eigh_full(a) + @test a * v ≈ v * d + @test arguments(d, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + + a = Eye{elt}(3) ⊗ Eye{elt}(3) + d, v = @constinferred eigh_full(a) + @test a * v ≈ v * d + @test arguments(d, 1) isa Eye{real(elt)} + @test arguments(d, 2) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test arguments(v, 2) isa Eye{elt} end for f in (eig_trunc, eigh_trunc) @@ -211,77 +242,110 @@ end end for f in (svd_compact, svd_full) - a = Eye(3) ⊗ randn(3, 3) - u, s, v = @constinferred f(a) - @test u * s * v ≈ a - @test arguments(u, 1) isa Eye - @test arguments(s, 1) isa Eye - @test arguments(v, 1) isa Eye - - a = randn(3, 3) ⊗ Eye(3) - u, s, v = @constinferred f(a) - @test u * s * v ≈ a - @test arguments(u, 2) isa Eye - @test arguments(s, 2) isa Eye - @test arguments(v, 2) isa Eye - - a = Eye(3) ⊗ Eye(3) - u, s, v = @constinferred f(a) - @test u * s * v ≈ a - @test arguments(u, 1) isa Eye - @test arguments(s, 1) isa Eye - @test arguments(v, 1) isa Eye - @test arguments(u, 2) isa Eye - @test arguments(s, 2) isa Eye - @test arguments(v, 2) isa Eye + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + u, s, v = @constinferred f(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + + a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + u, s, v = @constinferred f(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + + a = Eye{elt}(3) ⊗ Eye{elt}(3) + u, s, v = @constinferred f(a) + @test u * s * v ≈ a + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + end end # svd_trunc - a = Eye(3) ⊗ randn(3, 3) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 1) isa Eye - @test arguments(s, 1) isa Eye - @test arguments(v, 1) isa Eye - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 1) isa Eye{elt} + @test arguments(s, 1) isa Eye{real(elt)} + @test arguments(v, 1) isa Eye{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end - a = randn(3, 3) ⊗ Eye(3) - u, s, v = svd_trunc(a; trunc=(; maxrank=7)) - u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) - @test Matrix(u * s * v) ≈ u′ * s′ * v′ - @test arguments(u, 2) isa Eye - @test arguments(s, 2) isa Eye - @test arguments(v, 2) isa Eye - @test size(u) == (9, 6) - @test size(s) == (6, 6) - @test size(v) == (6, 9) + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + # TODO: Type inference is broken for `svd_trunc`, + # look into fixing it. + # u, s, v = @constinferred svd_trunc(a; trunc=(; maxrank=7)) + u, s, v = svd_trunc(a; trunc=(; maxrank=7)) + @test eltype(u) === elt + @test eltype(s) === real(elt) + @test eltype(v) === elt + u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=6)) + @test Matrix(u * s * v) ≈ u′ * s′ * v′ + @test arguments(u, 2) isa Eye{elt} + @test arguments(s, 2) isa Eye{real(elt)} + @test arguments(v, 2) isa Eye{elt} + @test size(u) == (9, 6) + @test size(s) == (6, 6) + @test size(v) == (6, 9) + end a = Eye(3) ⊗ Eye(3) @test_throws ArgumentError svd_trunc(a) # svd_vals - a = Eye(3) ⊗ randn(3, 3) - d = @constinferred svd_vals(a) - d′ = svd_vals(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ randn(elt, 3, 3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 1) isa Ones{real(elt)} + @test arguments(d, 2) ≈ svd_vals(arguments(a, 2)) + end - a = randn(3, 3) ⊗ Eye(3) - d = @constinferred svd_vals(a) - d′ = svd_vals(Matrix(a)) - @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) - @test arguments(d, 2) isa Ones - @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) + for elt in (Float32, ComplexF32) + a = randn(elt, 3, 3) ⊗ Eye{elt}(3) + d = @constinferred svd_vals(a) + d′ = svd_vals(Matrix(a)) + @test sort(Vector(d); by=abs) ≈ sort(d′; by=abs) + @test arguments(d, 2) isa Ones{real(elt)} + @test arguments(d, 1) ≈ svd_vals(arguments(a, 1)) + end - a = Eye(3) ⊗ Eye(3) - d = @constinferred svd_vals(a) - @test d == Ones(3) ⊗ Ones(3) - @test arguments(d, 1) isa Ones - @test arguments(d, 2) isa Ones + for elt in (Float32, ComplexF32) + a = Eye{elt}(3) ⊗ Eye{elt}(3) + d = @constinferred svd_vals(a) + @test d == Ones(3) ⊗ Ones(3) + @test arguments(d, 1) isa Ones{real(elt)} + @test arguments(d, 2) isa Ones{real(elt)} + end # left_null a = Eye(3) ⊗ randn(3, 3)