Skip to content

Commit cef70fa

Browse files
committed
More tests
1 parent 9e66bf0 commit cef70fa

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

src/diagonalarray/diagonalarray.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,35 @@ function DiagonalArray(::UndefInitializer, unstored::Unstored)
2727
return DiagonalArray(Vector{eltype(unstored)}(undef, minimum(size(unstored))), unstored)
2828
end
2929

30+
function construct_from_length(vect::Type{<:AbstractVector}, len::Integer)
31+
if applicable(vect, len)
32+
return vect(len)
33+
elseif applicable(vect, (Base.OneTo(len),))
34+
return vect((Base.OneTo(len),))
35+
else
36+
error(lazy"Can't construct $(vect) from length.")
37+
end
38+
end
39+
3040
# This helps to support diagonals where the elements are known
3141
# from the types, for example diagonals that are `Zeros` and `Ones`.
3242
function DiagonalArray{T,N,D,U}(
3343
ax::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}
3444
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
35-
return DiagonalArray(D((Base.OneTo(minimum(length, ax)),)), Unstored(U(ax)))
45+
return DiagonalArray(construct_from_length(D, minimum(length, ax)), Unstored(U(ax)))
3646
end
3747
function DiagonalArray{T,N,D,U}(
3848
ax1::AbstractUnitRange{<:Integer}, ax_rest::Vararg{AbstractUnitRange{<:Integer}}
3949
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
4050
return DiagonalArray{T,N,D,U}((ax1, ax_rest...))
4151
end
4252
function DiagonalArray{T,N,D,U}(
43-
sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}}
53+
sz::Tuple{Integer,Vararg{Integer}}
4454
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
4555
return DiagonalArray{T,N,D,U}(Base.OneTo.(sz))
4656
end
4757
function DiagonalArray{T,N,D,U}(
48-
sz1::Integer, sz_rest::Vararg{Integer}
58+
sz1::Integer, sz_rest::Integer...
4959
) where {T,N,D<:AbstractVector{T},U<:AbstractArray{T,N}}
5060
return DiagonalArray{T,N,D,U}((sz1, sz_rest...))
5161
end
@@ -64,12 +74,12 @@ function DiagonalArray{T,N,D}(
6474
return DiagonalArray{T,N,D,Zeros{T,N}}(ax1, ax_rest...)
6575
end
6676
function DiagonalArray{T,N,D}(
67-
sz::Tuple{Integer,Vararg{AbstractUnitRange{<:Integer}}}
77+
sz::Tuple{Integer,Vararg{Integer}}
6878
) where {T,N,D<:AbstractVector{T}}
6979
return DiagonalArray{T,N,D,Zeros{T,N}}(sz)
7080
end
7181
function DiagonalArray{T,N,D}(
72-
sz1::Integer, sz_rest::Vararg{Integer}
82+
sz1::Integer, sz_rest::Integer...
7383
) where {T,N,D<:AbstractVector{T}}
7484
return DiagonalArray{T,N,D,Zeros{T,N}}(sz1, sz_rest...)
7585
end

test/test_basics.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using DiagonalArrays:
1515
diagonal,
1616
diagonaltype,
1717
diagview
18-
using FillArrays: Fill, Ones
18+
using FillArrays: Fill, Ones, Zeros
1919
using SparseArraysBase: SparseArrayDOK, sparsezeros, storedlength
2020
using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric
2121

@@ -105,6 +105,17 @@ using LinearAlgebra: Diagonal, mul!, ishermitian, isposdef, issymmetric
105105
eltype(DiagonalArray{elt,2}(undef, (2, 2)))
106106
eltype(DiagonalArray{elt,2}(undef, Base.OneTo(2), Base.OneTo(2)))
107107
eltype(DiagonalArray{elt,2}(undef, (Base.OneTo(2), Base.OneTo(2))))
108+
109+
# Special constructors for immutable diagonal.
110+
@test DiagonalMatrix(Base.OneTo(UInt32(2)))
111+
DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(Base.OneTo.((2, 2)))
112+
DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(Base.OneTo.((2, 2))...)
113+
DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}((2, 2))
114+
DiagonalArray{UInt32,2,Base.OneTo{UInt32},Zeros{UInt32,2}}(2, 2)
115+
DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(Base.OneTo.((2, 2)))
116+
DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(Base.OneTo.((2, 2))...)
117+
DiagonalArray{UInt32,2,Base.OneTo{UInt32}}((2, 2))
118+
DiagonalArray{UInt32,2,Base.OneTo{UInt32}}(2, 2)
108119
end
109120
@testset "permutedims" begin
110121
a = DiagonalArray(randn(elt, 2), (2, 3, 4))

0 commit comments

Comments
 (0)