Skip to content

Commit 9c42b2a

Browse files
authored
More general diagonal, diagonaltype (#28)
1 parent 72954e9 commit 9c42b2a

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DiagonalArrays"
22
uuid = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.5"
4+
version = "0.3.6"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/diaginterface/diaginterface.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,9 @@ function diagview(a::AbstractArray)
8484
return @view a[diagindices(a)]
8585
end
8686

87+
using LinearAlgebra: Diagonal
88+
diagview(a::Diagonal) = a.diag
89+
8790
function getdiagindex(a::AbstractArray, i::Integer)
8891
return diagview(a)[i]
8992
end
@@ -110,7 +113,36 @@ end
110113
diagonal(v::AbstractVector) -> AbstractMatrix
111114
112115
Return a diagonal matrix from a vector `v`.
113-
This is an extension of `LinearAlgebra.Diagonal`, designed to avoid the implication of the output type.
116+
This is an extension of `LinearAlgebra.Diagonal`, designed to avoid
117+
the implication of the output type.
114118
Defaults to `Diagonal(v)`.
115119
"""
116120
diagonal(v::AbstractVector) = LinearAlgebra.Diagonal(v)
121+
122+
"""
123+
diagonal(m::AbstractMatrix) -> AbstractMatrix
124+
125+
Return a diagonal matrix from a matrix `m` where the diagonal
126+
values are copied from the diagonal of `m`.
127+
This is an extension of `LinearAlgebra.Diagonal`, designed to avoid
128+
the implication of the output type.
129+
Defaults to `diagonal(copy(diagview(m)))`, which in general is
130+
equivalent to `Diagonal(m)`.
131+
"""
132+
diagonal(m::AbstractMatrix) = diagonal(copy(diagview(m)))
133+
134+
"""
135+
diagonaltype(::AbstractVector) -> Type{<:AbstractMatrix}
136+
diagonaltype(::Type{<:AbstractVector}) -> Type{<:AbstractMatrix}
137+
diagonaltype(::AbstractMatrix) -> Type{<:AbstractMatrix}
138+
diagonaltype(::Type{<:AbstractMatrix}) -> Type{<:AbstractMatrix}
139+
140+
Return the type of diagonal matrix that would be created from a vector or matrix
141+
using the [`diagonal`](@ref) function.
142+
"""
143+
diagonaltype
144+
145+
diagonaltype(v::AbstractVector) = diagonaltype(typeof(v))
146+
diagonaltype(V::Type{<:AbstractVector}) = Base.promote_op(diagonal, V)
147+
diagonaltype(m::AbstractMatrix) = diagonaltype(typeof(m))
148+
diagonaltype(M::Type{<:AbstractMatrix}) = Base.promote_op(diagonal, M)

test/test_basics.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DiagonalArrays:
88
diagindices,
99
diaglength,
1010
diagonal,
11+
diagonaltype,
1112
diagview
1213
using FillArrays: Fill, Ones
1314
using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength
@@ -104,8 +105,17 @@ using LinearAlgebra: Diagonal
104105
@test a_dest isa SparseArrayDOK{elt,2}
105106
end
106107
@testset "diagonal" begin
107-
@test @inferred(diagonal(rand(2))) isa AbstractMatrix
108-
@test diagonal(zeros(Int, 2)) isa Diagonal
108+
v = randn(2)
109+
d = @inferred diagonal(v)
110+
@test d isa Diagonal{eltype(v)}
111+
@test diagview(d) === v
112+
@test diagonaltype(v) === typeof(d)
113+
114+
a = randn(2, 2)
115+
d = @inferred diagonal(a)
116+
@test d isa Diagonal{eltype(v)}
117+
@test diagview(d) == diagview(a)
118+
@test diagonaltype(a) === typeof(d)
109119
end
110120
@testset "delta" begin
111121
for (a, elt′) in (

0 commit comments

Comments
 (0)