diff --git a/Project.toml b/Project.toml index 9c70da8..396e917 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiagonalArrays" uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77" authors = ["ITensor developers and contributors"] -version = "0.3.13" +version = "0.3.14" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/diagonalarray/diagonalarray.jl b/src/diagonalarray/diagonalarray.jl index f3518c5..96e5d67 100644 --- a/src/diagonalarray/diagonalarray.jl +++ b/src/diagonalarray/diagonalarray.jl @@ -120,6 +120,13 @@ function DiagonalArray{<:Any,N}(diag::AbstractVector{T}, dims::Vararg{Int,N}) wh return DiagonalArray{T,N}(diag, dims) end +function DiagonalArray{<:Any,N}( + diag::AbstractVector{T}, + ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, +) where {T,N} + return DiagonalArray{T,N}(diag, ax) +end + function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}) where {T,N} return DiagonalArray{T,N}(diag, dims) end @@ -163,6 +170,10 @@ function Base.similar(a::DiagonalArray, unstored::Unstored) return DiagonalArray(undef, unstored) end +# This definition is helpful for immutable diagonals +# such as FillArrays. +Base.copy(a::DiagonalArray) = DiagonalArray(copy(diagview(a)), axes(a)) + # DiagonalArrays interface. diagview(a::DiagonalArray) = a.diag diff --git a/src/diagonalarray/diagonalmatrix.jl b/src/diagonalarray/diagonalmatrix.jl index 873410d..9c77c25 100644 --- a/src/diagonalarray/diagonalmatrix.jl +++ b/src/diagonalarray/diagonalmatrix.jl @@ -3,3 +3,63 @@ const DiagonalMatrix{T,Diag,Zero} = DiagonalArray{T,2,Diag,Zero} function DiagonalMatrix(diag::AbstractVector) return DiagonalArray{<:Any,2}(diag) end +function DiagonalMatrix(diag::AbstractVector, ax::Tuple) + return DiagonalArray{<:Any,2}(diag, ax) +end + +# LinearAlgebra + +using LinearAlgebra: LinearAlgebra + +function mul_diagviews(a1, a2) + # TODO: Compare that duals are equal, or define a function to overload. + axes(a1, 2) == axes(a2, 1) || throw( + DimensionMismatch( + lazy"Incompatible dimensions for multiplication: $(axes(a1)) and $(axes(a2))" + ), + ) + d1 = diagview(a1) + d2 = diagview(a2) + l = min(length(d1), length(d2)) + d1′ = view(d1, Base.OneTo(l)) + d2′ = view(d2, Base.OneTo(l)) + return (d1′, d2′) +end + +function mul!_diagviews(a_dest, a1, a2) + axes(a_dest, 1) == axes(a1, 1) || throw( + DimensionMismatch( + lazy"Incompatible dimensions for multiplication: $(axes(a_dest)) and $(axes(a1))" + ), + ) + axes(a_dest, 2) == axes(a2, 2) || throw( + DimensionMismatch( + lazy"Incompatible dimensions for multiplication: $(axes(a_dest)) and $(axes(a2))" + ), + ) + d_dest = diagview(a_dest) + d1, d2 = mul_diagviews(a1, a2) + return d_dest, d1, d2 +end + +function Base.:*(a1::DiagonalMatrix, a2::DiagonalMatrix) + d1, d2 = mul_diagviews(a1, a2) + # TODO: Handle the rack-deficient case, for example: + # δ(3, 2) * δ(2, 3) + # Maybe pack the diagonal with zeros or allow rank-deficient DiagonalArrays? + return DiagonalMatrix(d1 .* d2, (axes(a1, 1), axes(a2, 2))) +end +function LinearAlgebra.mul!(a_dest::DiagonalMatrix, a1::DiagonalMatrix, a2::DiagonalMatrix) + d_dest, d1, d2 = mul!_diagviews(a_dest, a1, a2) + # TODO: Handle the rack-deficient case. + d_dest .= d1 .* d2 + return a_dest +end +function LinearAlgebra.mul!( + a_dest::DiagonalMatrix, a1::DiagonalMatrix, a2::DiagonalMatrix, α::Number, β::Number +) + d_dest, d1, d2 = mul!_diagviews(a_dest, a1, a2) + # TODO: Handle the rack-deficient case. + d_dest .= d1 .* d2 .* α .+ d_dest .* β + return a_dest +end diff --git a/test/test_basics.jl b/test/test_basics.jl index 8ff3ca8..a7f1c00 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -14,7 +14,7 @@ using DiagonalArrays: diagview using FillArrays: Fill, Ones using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength -using LinearAlgebra: Diagonal +using LinearAlgebra: Diagonal, mul! @testset "Test DiagonalArrays" begin @testset "DiagonalArray (eltype=$elt)" for elt in ( @@ -131,6 +131,15 @@ using LinearAlgebra: Diagonal @test storedlength(a_dest) == 2 @test a_dest isa DiagonalMatrix{elt} + a_dest = DiagonalArray{elt}(undef, (2, 4)) + mul!(a_dest, a1, a2) + @test Array(a_dest) ≈ Array(a1) * Array(a2) + + a_dest = DiagonalArray(randn(elt, 2), (2, 4)) + a_dest′ = copy(a_dest) + mul!(a_dest′, a1, a2, 2, 3) + @test Array(a_dest′) ≈ Array(a1) * Array(a2) * 2 + Array(a_dest) * 3 + # TODO: Make generic to GPU, use `allocate_randn`? a2 = randn(elt, (3, 4)) a_dest = a1 * a2 @@ -195,6 +204,7 @@ using LinearAlgebra: Diagonal @test a == DiagonalArray(ones(2), (2, 2)) @test diagview(a) == ones(2) @test diagview(a) isa Ones{elt′} + @test copy(a) ≡ a a′ = 2a @test diagview(a′) == 2ones(2)