Skip to content

Commit e28021e

Browse files
authored
diagonal special cases and colsupport for subarray (#54)
* special case of permutedims for Diagonal * add permutedims test * v0.5.3 * Add diagonal special case, colsupport for SubArray * Update ArrayLayouts.jl * test on 1.6-rc1 * Increase coverage * Update test_layoutarray.jl
1 parent e9b6ce0 commit e28021e

File tree

4 files changed

+53
-1
lines changed

4 files changed

+53
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.5.3"
4+
version = "0.5.4"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/ArrayLayouts.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ macro layoutgetindex(Typ)
132132
ArrayLayouts.@_layoutgetindex LinearAlgebra.Hermitian{<:Any,<:$Typ}
133133
ArrayLayouts.@_layoutgetindex LinearAlgebra.Adjoint{<:Any,<:$Typ}
134134
ArrayLayouts.@_layoutgetindex LinearAlgebra.Transpose{<:Any,<:$Typ}
135+
ArrayLayouts.@_layoutgetindex LinearAlgebra.SubArray{<:Any,2,<:$Typ}
135136
end)
136137
end
137138

@@ -162,6 +163,29 @@ getindex(A::LayoutVector, kr::Colon) = layout_getindex(A, kr)
162163
getindex(A::AdjOrTrans{<:Any,<:LayoutVector}, kr::Integer, jr::Colon) = layout_getindex(A, kr, jr)
163164
getindex(A::AdjOrTrans{<:Any,<:LayoutVector}, kr::Integer, jr::AbstractVector) = layout_getindex(A, kr, jr)
164165

166+
*(A::Diagonal{<:Any,<:LayoutVector}, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B)
167+
*(A::Diagonal{<:Any,<:LayoutVector}, B::AbstractMatrix) = mul(A, B)
168+
*(A::AbstractMatrix, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B)
169+
*(A::Diagonal{<:Any,<:LayoutVector}, B::LayoutMatrix) = mul(A, B)
170+
*(A::LayoutMatrix, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B)
171+
*(A::Diagonal{<:Any,<:LayoutVector}, B::Diagonal) = mul(A, B)
172+
*(A::Diagonal, B::Diagonal{<:Any,<:LayoutVector}) = mul(A, B)
173+
174+
for Mod in (:Adjoint, :Transpose, :Symmetric, :Hermitian)
175+
@eval begin
176+
*(A::Diagonal{<:Any,<:LayoutVector}, B::$Mod{<:Any,<:LayoutMatrix}) = mul(A,B)
177+
*(A::$Mod{<:Any,<:LayoutMatrix}, B::Diagonal{<:Any,<:LayoutVector}) = mul(A,B)
178+
end
179+
end
180+
\(A::Diagonal{<:Any,<:LayoutVector}, B::Diagonal{<:Any,<:LayoutVector}) = ldiv(A, B)
181+
\(A::Diagonal{<:Any,<:LayoutVector}, B::AbstractMatrix) = ldiv(A, B)
182+
\(A::AbstractMatrix, B::Diagonal{<:Any,<:LayoutVector}) = ldiv(A, B)
183+
\(A::Diagonal{<:Any,<:LayoutVector}, B::LayoutMatrix) = ldiv(A, B)
184+
\(A::LayoutMatrix, B::Diagonal{<:Any,<:LayoutVector}) = ldiv(A, B)
185+
\(A::Diagonal{<:Any,<:LayoutVector}, B::Diagonal) = ldiv(A, B)
186+
\(A::Diagonal, B::Diagonal{<:Any,<:LayoutVector}) = ldiv(A, B)
187+
188+
165189
_copyto!(_, _, dest::AbstractArray{T,N}, src::AbstractArray{V,N}) where {T,V,N} =
166190
Base.invoke(copyto!, Tuple{AbstractArray{T,N},AbstractArray{V,N}}, dest, src)
167191

src/memorylayout.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ rowsupport(A) = rowsupport(A, axes(A,1))
632632

633633
colsupport(_, A, j) = axes(A,1)
634634

635+
635636
""""
636637
colsupport(A, j)
637638
@@ -640,6 +641,12 @@ gives an iterator containing the possible non-zero entries in the j-th column of
640641
colsupport(A, j) = colsupport(MemoryLayout(A), A, j)
641642
colsupport(A) = colsupport(A, axes(A,2))
642643

644+
# TODO: generalise to other subarrays
645+
function colsupport(A::SubArray{<:Any,N,<:Any,<:Tuple{Slice,Any}}, j) where N
646+
_, jr = parentindices(A)
647+
colsupport(parent(A), jr[j])
648+
end
649+
643650
rowsupport(::ZerosLayout, A, _) = 1:0
644651
colsupport(::ZerosLayout, A, _) = 1:0
645652

test/test_layoutarray.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,27 @@ MemoryLayout(::Type{MyVector}) = DenseColumnMajor()
192192
@test B - I A - I
193193
@test I - B I - B
194194
end
195+
196+
@testset "Diagonal" begin
197+
D = Diagonal(MyVector(randn(5)))
198+
= Diagonal(Vector(D.diag))
199+
B = randn(5,5)
200+
= MyMatrix(B)
201+
@test D*D Matrix(D)^2
202+
@test_broken D^2 D*D
203+
@test D*B Matrix(D)*B
204+
@test B*D B*Matrix(D)
205+
@test D* Matrix(D)*
206+
@test*D *Matrix(D)
207+
@test D**D
208+
209+
@test D\D I
210+
@test D\B Matrix(D)\B
211+
@test B\D B\Matrix(D)
212+
@test D\ Matrix(D)\
213+
@test\D \Matrix(D)
214+
@test D\\D
215+
end
195216
end
196217

197218
struct MyUpperTriangular{T} <: AbstractMatrix{T}

0 commit comments

Comments
 (0)