Skip to content

Commit 91da4bf

Browse files
authored
LinearAlgebra: diagzero for non-OneTo axes (#55252)
Currently, the off-diagonal zeros for a block-`Diagonal` matrix is computed using `diagzero`, which calls `zeros` for the sizes of the elements. This returns an `Array`, unless one specializes `diagzero` for the custom `Diagonal` matrix type. This PR defines a `zeroslike` function that dispatches on the axes of the elements, which lets packages specialize on the axes to return custom `AbstractArray`s. Choosing to specialize on the `eltype` avoids the need to specialize on the container, and allows packages to return appropriate types for custom axis types. With this, ```julia julia> LinearAlgebra.zeroslike(::Type{S}, ax::Tuple{SOneTo, Vararg{SOneTo}}) where {S<:SMatrix} = SMatrix{map(length, ax)...}(ntuple(_->zero(eltype(S)), prod(length, ax))) julia> D = Diagonal(fill(SMatrix{2,3}(1:6), 2)) 2×2 Diagonal{SMatrix{2, 3, Int64, 6}, Vector{SMatrix{2, 3, Int64, 6}}}: [1 3 5; 2 4 6] ⋅ ⋅ [1 3 5; 2 4 6] julia> D[1,2] # now an SMatrix 2×3 SMatrix{2, 3, Int64, 6} with indices SOneTo(2)×SOneTo(3): 0 0 0 0 0 0 julia> LinearAlgebra.zeroslike(::Type{S}, ax::Tuple{SOneTo, Vararg{SOneTo}}) where {S<:MMatrix} = MMatrix{map(length, ax)...}(ntuple(_->zero(eltype(S)), prod(length, ax))) julia> D = Diagonal(fill(MMatrix{2,3}(1:6), 2)) 2×2 Diagonal{MMatrix{2, 3, Int64, 6}, Vector{MMatrix{2, 3, Int64, 6}}}: [1 3 5; 2 4 6] ⋅ ⋅ [1 3 5; 2 4 6] julia> D[1,2] # now an MMatrix 2×3 MMatrix{2, 3, Int64, 6} with indices SOneTo(2)×SOneTo(3): 0 0 0 0 0 0 ``` The reason this can't be the default behavior is that we are not guaranteed that there exists a `similar` method that accepts the combination of axes. This is why we have to fall back to using the sizes, unless a specialized method is provided by a package. One positive outcome of this is that indexing into such a block-diagonal matrix will now usually be type-stable, which mitigates https://github.com/JuliaLang/julia/issues/45535 to some extent (although it doesn't resolve the issue). I've also updated the `getindex` for `Bidiagonal` to use `diagzero`, instead of the similarly defined `bidiagzero` function that it was using. Structured block matrices may now use `diagzero` uniformly to generate the zero elements.
1 parent 9c55783 commit 91da4bf

File tree

7 files changed

+51
-10
lines changed

7 files changed

+51
-10
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ Standard library changes
138138
(callable via `cholesky[!](A, RowMaximum())`) ([#54619]).
139139
* The number of default BLAS threads now respects process affinity, instead of
140140
using total number of logical threads available on the system ([#55574]).
141+
* A new function `zeroslike` is added that is used to generate the zero elements for matrix-valued banded matrices.
142+
Custom array types may specialize this function to return an appropriate result. ([#55252])
141143

142144
#### Logging
143145

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ public AbstractTriangular,
175175
peakflops,
176176
symmetric,
177177
symmetric_type,
178+
zeroslike,
178179
matprod_dest
179180

180181
const BlasFloat = Union{Float64,Float32,ComplexF64,ComplexF32}

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,14 @@ Bidiagonal(A::Bidiagonal) = A
118118
Bidiagonal{T}(A::Bidiagonal{T}) where {T} = A
119119
Bidiagonal{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A.dv, A.ev, A.uplo)
120120

121-
bidiagzero(::Bidiagonal{T}, i, j) where {T} = zero(T)
122-
function bidiagzero(A::Bidiagonal{<:AbstractMatrix}, i, j)
123-
Tel = eltype(eltype(A.dv))
121+
function diagzero(A::Bidiagonal{<:AbstractMatrix}, i, j)
122+
Tel = eltype(A)
124123
if i < j && A.uplo == 'U' #= top right zeros =#
125-
return zeros(Tel, size(A.ev[i], 1), size(A.ev[j-1], 2))
124+
return zeroslike(Tel, axes(A.ev[i], 1), axes(A.ev[j-1], 2))
126125
elseif j < i && A.uplo == 'L' #= bottom left zeros =#
127-
return zeros(Tel, size(A.ev[i-1], 1), size(A.ev[j], 2))
126+
return zeroslike(Tel, axes(A.ev[i-1], 1), axes(A.ev[j], 2))
128127
else
129-
return zeros(Tel, size(A.dv[i], 1), size(A.dv[j], 2))
128+
return zeroslike(Tel, axes(A.dv[i], 1), axes(A.dv[j], 2))
130129
end
131130
end
132131

@@ -161,7 +160,7 @@ end
161160
elseif i == j - _offdiagind(A.uplo)
162161
return @inbounds A.ev[A.uplo == 'U' ? i : j]
163162
else
164-
return bidiagzero(A, i, j)
163+
return diagzero(A, i, j)
165164
end
166165
end
167166

@@ -173,7 +172,7 @@ end
173172
# we explicitly compare the possible bands as b.band may be constant-propagated
174173
return @inbounds A.ev[b.index]
175174
else
176-
return bidiagzero(A, Tuple(_cartinds(b))...)
175+
return diagzero(A, Tuple(_cartinds(b))...)
177176
end
178177
end
179178

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,27 @@ end
185185
end
186186
r
187187
end
188-
diagzero(::Diagonal{T}, i, j) where {T} = zero(T)
189-
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = zeros(T, size(D.diag[i], 1), size(D.diag[j], 2))
188+
"""
189+
diagzero(A::AbstractMatrix, i, j)
190+
191+
Return the appropriate zero element `A[i, j]` corresponding to a banded matrix `A`.
192+
"""
193+
diagzero(A::AbstractMatrix, i, j) = zero(eltype(A))
194+
diagzero(D::Diagonal{M}, i, j) where {M<:AbstractMatrix} =
195+
zeroslike(M, axes(D.diag[i], 1), axes(D.diag[j], 2))
196+
# dispatching on the axes permits specializing on the axis types to return something other than an Array
197+
zeroslike(M::Type, ax::Vararg{Union{AbstractUnitRange, Integer}}) = zeroslike(M, ax)
198+
"""
199+
zeroslike(::Type{M}, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) where {M<:AbstractMatrix}
200+
zeroslike(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix}
201+
202+
Return an appropriate zero-ed array similar to `M`, with either the axes `ax` or the size `sz`.
203+
This will be used as a structural zero element of a matrix-valued banded matrix.
204+
By default, `zeroslike` falls back to using the size along each axis to construct the array.
205+
"""
206+
zeroslike(M::Type, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = zeroslike(M, map(length, ax))
207+
zeroslike(M::Type, sz::Tuple{Integer, Vararg{Integer}}) = zeros(M, sz)
208+
zeroslike(::Type{M}, sz::Tuple{Integer, Vararg{Integer}}) where {M<:AbstractMatrix} = zeros(eltype(M), sz)
190209

191210
@inline function getindex(D::Diagonal, b::BandIndex)
192211
@boundscheck checkbounds(D, b)

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,16 @@ end
839839
B = Bidiagonal(dv, ev, :U)
840840
@test B == Matrix{eltype(B)}(B)
841841
end
842+
843+
@testset "non-standard axes" begin
844+
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
845+
zeros(T, ax)
846+
847+
s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
848+
B = Bidiagonal(fill(s,4), fill(s,3), :U)
849+
@test @inferred(B[2,1]) isa typeof(s)
850+
@test all(iszero, B[2,1])
851+
end
842852
end
843853

844854
@testset "copyto!" begin

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,13 @@ end
815815
D = Diagonal(fill(S,3))
816816
@test D * fill(S,2,3)' == fill(S * S', 3, 2)
817817
@test fill(S,3,2)' * D == fill(S' * S, 2, 3)
818+
819+
@testset "indexing with non-standard-axes" begin
820+
s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
821+
D = Diagonal(fill(s,3))
822+
@test @inferred(D[1,2]) isa typeof(s)
823+
@test all(iszero, D[1,2])
824+
end
818825
end
819826

820827
@testset "Eigensystem for block diagonal (issue #30681)" begin

test/testhelpers/SizedArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,7 @@ mul!(dest::AbstractMatrix, S1::SizedMatrix, S2::SizedMatrix, α::Number, β::Num
9999
mul!(dest::AbstractVector, M::AbstractMatrix, v::SizedVector, α::Number, β::Number) =
100100
mul!(dest, M, _data(v), α, β)
101101

102+
LinearAlgebra.zeroslike(::Type{S}, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) where {S<:SizedArray} =
103+
zeros(eltype(S), ax)
104+
102105
end

0 commit comments

Comments
 (0)