Skip to content

Commit 0cd6083

Browse files
committed
Fixes for SparseArraysBase v0.7
1 parent 26b221d commit 0cd6083

File tree

1 file changed

+50
-74
lines changed

1 file changed

+50
-74
lines changed

src/diagonalarray/diagonalarray.jl

Lines changed: 50 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,123 +1,99 @@
1-
function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
2-
return zero(eltype(a))
3-
end
1+
using FillArrays: Zeros
2+
using SparseArraysBase: Unstored
43

54
function _DiagonalArray end
65

7-
struct DiagonalArray{T,N,Diag<:AbstractVector{T},F} <: AbstractDiagonalArray{T,N}
6+
struct DiagonalArray{T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}} <:
7+
AbstractDiagonalArray{T,N}
88
diag::Diag
9-
dims::Dims{N}
10-
getunstored::F
9+
unstored::Unstored
1110
global @inline function _DiagonalArray(
12-
diag::Diag, dims::Dims{N}, getunstored::F
13-
) where {T,N,Diag<:AbstractVector{T},F}
14-
all((0), dims) || throw(ArgumentError("Invalid dimensions: $dims"))
15-
length(diag) == minimum(dims) ||
11+
diag::Diag, unstored::Unstored
12+
) where {T,N,Diag<:AbstractVector{T},Unstored<:AbstractArray{T,N}}
13+
length(diag) == minimum(size(unstored)) ||
1614
throw(ArgumentError("Length of diagonals doesn't match dimensions"))
17-
return new{T,N,Diag,F}(diag, dims, getunstored)
15+
return new{T,N,Diag,Unstored}(diag, unstored)
1816
end
1917
end
2018

21-
function DiagonalArray{T,N}(
22-
diag::AbstractVector, dims::Dims{N}; getunstored=getzero
23-
) where {T,N}
19+
SparseArraysBase.unstored(a::DiagonalArray) = a.unstored
20+
Base.size(a::DiagonalArray) = size(unstored(a))
21+
Base.axes(a::DiagonalArray) = axes(unstored(a))
22+
23+
function DiagonalArray(::UndefInitializer, unstored::Unstored)
24+
return _DiagonalArray(Vector{eltype(unstored)}(undef, ndims(unstored)), parent(unstored))
25+
end
26+
27+
function DiagonalArray{T,N}(diag::AbstractVector, unstored::AbstractArray) where {T,N}
2428
return _DiagonalArray(convert(AbstractVector{T}, diag), dims, getunstored)
2529
end
2630

27-
function DiagonalArray{T,N}(
28-
diag::AbstractVector, dims::Vararg{Int,N}; kwargs...
29-
) where {T,N}
30-
return DiagonalArray{T,N}(diag, dims; kwargs...)
31+
function DiagonalArray{T,N}(diag::AbstractVector, dims::Dims{N}) where {T,N}
32+
return _DiagonalArray(convert(AbstractVector{T}, diag), Zeros{T}(dims))
33+
end
34+
35+
function DiagonalArray{T,N}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N}
36+
return DiagonalArray{T,N}(diag, dims)
3137
end
3238

33-
function DiagonalArray{T}(diag::AbstractVector, dims::Dims{N}; kwargs...) where {T,N}
34-
return DiagonalArray{T,N}(diag, dims; kwargs...)
39+
function DiagonalArray{T}(diag::AbstractVector, dims::Dims{N}) where {T,N}
40+
return DiagonalArray{T,N}(diag, dims)
3541
end
3642

37-
function DiagonalArray{T}(diag::AbstractVector, dims::Vararg{Int,N}; kwargs...) where {T,N}
38-
return DiagonalArray{T,N}(diag, dims; kwargs...)
43+
function DiagonalArray{T}(diag::AbstractVector, dims::Vararg{Int,N}) where {T,N}
44+
return DiagonalArray{T,N}(diag, dims)
3945
end
4046

41-
function DiagonalArray{<:Any,N}(
42-
diag::AbstractVector{T}, dims::Dims{N}; kwargs...
43-
) where {T,N}
44-
return DiagonalArray{T,N}(diag, dims; kwargs...)
47+
function DiagonalArray{<:Any,N}(diag::AbstractVector{T}, dims::Dims{N}) where {T,N}
48+
return DiagonalArray{T,N}(diag, dims)
4549
end
4650

47-
function DiagonalArray{<:Any,N}(
48-
diag::AbstractVector{T}, dims::Vararg{Int,N}; kwargs...
49-
) where {T,N}
50-
return DiagonalArray{T,N}(diag, dims; kwargs...)
51+
function DiagonalArray{<:Any,N}(diag::AbstractVector{T}, dims::Vararg{Int,N}) where {T,N}
52+
return DiagonalArray{T,N}(diag, dims)
5153
end
5254

53-
function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}; kwargs...) where {T,N}
54-
return DiagonalArray{T,N}(diag, dims; kwargs...)
55+
function DiagonalArray(diag::AbstractVector{T}, dims::Dims{N}) where {T,N}
56+
return DiagonalArray{T,N}(diag, dims)
5557
end
5658

57-
function DiagonalArray(diag::AbstractVector{T}, dims::Vararg{Int,N}; kwargs...) where {T,N}
58-
return DiagonalArray{T,N}(diag, dims; kwargs...)
59+
function DiagonalArray(diag::AbstractVector{T}, dims::Vararg{Int,N}) where {T,N}
60+
return DiagonalArray{T,N}(diag, dims)
5961
end
6062

6163
# Infer size from diagonal
62-
function DiagonalArray{T,N}(diag::AbstractVector; kwargs...) where {T,N}
63-
return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N); kwargs...)
64+
function DiagonalArray{T,N}(diag::AbstractVector) where {T,N}
65+
return DiagonalArray{T,N}(diag, ntuple(Returns(length(diag)), N))
6466
end
6567

66-
function DiagonalArray{<:Any,N}(diag::AbstractVector{T}; kwargs...) where {T,N}
67-
return DiagonalArray{T,N}(diag; kwargs...)
68+
function DiagonalArray{<:Any,N}(diag::AbstractVector{T}) where {T,N}
69+
return DiagonalArray{T,N}(diag)
6870
end
6971

7072
# undef
71-
function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N}
72-
return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims; kwargs...)
73+
function DiagonalArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
74+
return DiagonalArray{T,N}(Vector{T}(undef, minimum(dims)), dims)
7375
end
7476

75-
function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}; kwargs...) where {T,N}
76-
return DiagonalArray{T,N}(undef, dims; kwargs...)
77+
function DiagonalArray{T,N}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
78+
return DiagonalArray{T,N}(undef, dims)
7779
end
7880

79-
function DiagonalArray{T}(::UndefInitializer, dims::Dims{N}; kwargs...) where {T,N}
80-
return DiagonalArray{T,N}(undef, dims; kwargs...)
81+
function DiagonalArray{T}(::UndefInitializer, dims::Dims{N}) where {T,N}
82+
return DiagonalArray{T,N}(undef, dims)
8183
end
8284

8385
function DiagonalArray{T}(::UndefInitializer, dims::Vararg{Int,N}) where {T,N}
8486
return DiagonalArray{T,N}(undef, dims)
8587
end
8688

8789
# Axes version
88-
function DiagonalArray{T}(
89-
::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}; kwargs...
90-
) where {T,N}
91-
return DiagonalArray{T,N}(undef, length.(axes); kwargs...)
90+
function DiagonalArray{T}(::UndefInitializer, axes::NTuple{N,Base.OneTo{Int}}) where {T,N}
91+
return DiagonalArray{T,N}(undef, length.(axes))
9292
end
9393

94-
# Minimal `AbstractArray` interface
95-
Base.size(a::DiagonalArray) = a.dims
96-
97-
function Base.similar(a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}})
98-
function getzero(a::AbstractArray{<:Any,N}, I::Vararg{Int,N}) where {N}
99-
return convert(elt, a.getunstored(a, I...))
100-
end
101-
return DiagonalArray{elt}(undef, dims; getunstored=getzero)
94+
function SparseArraysBase.similar_sparsearray(a::DiagonalArray, unstored::Unstored)
95+
return DiagonalArray(undef, unstored)
10296
end
10397

10498
# DiagonalArrays interface.
10599
diagview(a::DiagonalArray) = a.diag
106-
107-
# Minimal `SparseArraysBase` interface
108-
## SparseArraysBase.sparse_storage(a::DiagonalArray) = a.diag
109-
110-
# `SparseArraysBase`
111-
# Defines similar when the output can't be `DiagonalArray`,
112-
# such as in `reshape`.
113-
# TODO: Put into `DiagonalArraysSparseArraysBaseExt`?
114-
# TODO: Special case 2D to output `SparseMatrixCSC`?
115-
## function SparseArraysBase.sparse_similar(
116-
## a::DiagonalArray, elt::Type, dims::Tuple{Vararg{Int}}
117-
## )
118-
## return SparseArrayDOK{elt}(undef, dims, getindex_zero_function(a))
119-
## end
120-
121-
## function SparseArraysBase.getindex_zero_function(a::DiagonalArray)
122-
## return a.zero
123-
## end

0 commit comments

Comments
 (0)