Skip to content

Commit 0273e85

Browse files
authored
special case for multiplying Diagonal fill arrays, including Eye (#60)
* special case for multiplying Diagonal fill arrays, including Eye * Update test_layouts.jl * increase coverage * increase cov * Update test_layouts.jl * Update test_layouts.jl
1 parent 1c8b423 commit 0273e85

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.6.3"
4+
version = "0.6.4"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/diagonal.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ copy(M::Rdiv{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* inv(getinde
4747
copy(M::Rdiv{<:DiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = Diagonal(M.A.diag .* inv(getindex_value(M.B.diag)))
4848

4949

50+
5051
## bi/tridiagonal copy
5152
copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout}) = convert(Bidiagonal, M.A) * M.B
5253
copy(M::Lmul{<:DiagonalLayout,<:BidiagonalLayout}) = M.A * convert(Bidiagonal, M.B)
@@ -55,10 +56,27 @@ copy(M::Lmul{<:DiagonalLayout,<:TridiagonalLayout}) = M.A * convert(Tridiagonal,
5556
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout}) = convert(SymTridiagonal, M.A) * M.B
5657
copy(M::Lmul{<:DiagonalLayout,<:SymTridiagonalLayout}) = M.A * convert(SymTridiagonal, M.B)
5758

59+
copy(M::Lmul{DiagonalLayout{OnesLayout}}) = copy_oftype(M.B, eltype(M))
60+
copy(M::Lmul{DiagonalLayout{OnesLayout},<:DiagonalLayout}) = Diagonal(copy_oftype(diagonaldata(M.B), eltype(M)))
61+
copy(M::Lmul{<:DiagonalLayout,DiagonalLayout{OnesLayout}}) = Diagonal(copy_oftype(diagonaldata(M.A), eltype(M)))
62+
copy(M::Lmul{DiagonalLayout{OnesLayout},DiagonalLayout{OnesLayout}}) = copy_oftype(M.B, eltype(M))
63+
copy(M::Rmul{<:Any,DiagonalLayout{OnesLayout}}) = copy_oftype(M.A, eltype(M))
64+
65+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(diagonaldata(M.A)) * M.B
66+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
67+
copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
5868

5969
copy(M::Rmul{<:BidiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
6070
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:BidiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
6171
copy(M::Rmul{<:TridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
6272
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:TridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
6373
copy(M::Rmul{<:SymTridiagonalLayout,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A * getindex_value(diagonaldata(M.B))
6474
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:SymTridiagonalLayout}) = getindex_value(diagonaldata(M.A)) * M.B
75+
76+
77+
copy(M::Rmul{<:BidiagonalLayout,DiagonalLayout{OnesLayout}}) = copy_oftype(M.A, eltype(M))
78+
copy(M::Lmul{DiagonalLayout{OnesLayout},<:BidiagonalLayout}) = copy_oftype(M.B, eltype(M))
79+
copy(M::Rmul{<:TridiagonalLayout,DiagonalLayout{OnesLayout}}) = copy_oftype(M.A, eltype(M))
80+
copy(M::Lmul{DiagonalLayout{OnesLayout},<:TridiagonalLayout}) = copy_oftype(M.B, eltype(M))
81+
copy(M::Rmul{<:SymTridiagonalLayout,DiagonalLayout{OnesLayout}}) = copy_oftype(M.A, eltype(M))
82+
copy(M::Lmul{DiagonalLayout{OnesLayout},<:SymTridiagonalLayout}) = copy_oftype(M.B, eltype(M))

test/test_layouts.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,61 @@ struct FooNumber <: Number end
310310
@test MemoryLayout(LowerTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
311311
@test MemoryLayout(UnitUpperTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
312312
@test MemoryLayout(UnitLowerTriangular(T)) isa BidiagonalLayout{FillLayout,FillLayout}
313+
314+
B = Bidiagonal(Fill(1,11), Fill(2,10), :U)
315+
@test MemoryLayout(B) isa BidiagonalLayout{FillLayout,FillLayout}
316+
317+
S = SymTridiagonal(Fill(1,11), Fill(2,10))
318+
@test MemoryLayout(S) isa SymTridiagonalLayout{FillLayout,FillLayout}
319+
320+
321+
@test ArrayLayouts.mul(Eye{Int}(11), 1:11) 1:11
322+
@test ArrayLayouts.mul(Eye(11), 1:11) isa AbstractVector{Float64}
323+
@test ArrayLayouts.mul((1:11)', Eye{Int}(11)) isa AbstractMatrix{Int}
324+
@test ArrayLayouts.mul((1:11)', Eye(11)) isa AbstractMatrix{Float64}
325+
326+
D = Diagonal(1:5)
327+
@test ArrayLayouts.mul(D, Eye{Int}(5)) ArrayLayouts.mul(Eye{Int}(5), D) D
328+
@test ArrayLayouts.mul(D, Eye(5)) == ArrayLayouts.mul(Eye(5), D) == D
329+
330+
@test ArrayLayouts.mul(Eye{Int}(11), T) isa Tridiagonal{Int,<:Fill}
331+
@test ArrayLayouts.mul(T, Eye{Int}(11)) isa Tridiagonal{Int,<:Fill}
332+
@test ArrayLayouts.mul(Eye{Int}(11), T) isa Tridiagonal{Int,<:Fill}
333+
@test ArrayLayouts.mul(T, Eye{Int}(11)) isa Tridiagonal{Int,<:Fill}
334+
@test ArrayLayouts.mul(Eye{Int}(11), B) isa Bidiagonal{Int,<:Fill}
335+
@test ArrayLayouts.mul(B, Eye{Int}(11)) isa Bidiagonal{Int,<:Fill}
336+
@test ArrayLayouts.mul(Eye{Int}(11), B) isa Bidiagonal{Int,<:Fill}
337+
@test ArrayLayouts.mul(B, Eye{Int}(11)) isa Bidiagonal{Int,<:Fill}
338+
@test ArrayLayouts.mul(Eye{Int}(11), S) isa SymTridiagonal{Int,<:Fill}
339+
@test ArrayLayouts.mul(S, Eye{Int}(11)) isa SymTridiagonal{Int,<:Fill}
340+
341+
@test ArrayLayouts.mul(Eye(11), T) isa Tridiagonal{Float64,<:Fill}
342+
@test ArrayLayouts.mul(T, Eye(11)) isa Tridiagonal{Float64,<:Fill}
343+
@test ArrayLayouts.mul(Eye(11), T) isa Tridiagonal{Float64,<:Fill}
344+
@test ArrayLayouts.mul(T, Eye(11)) isa Tridiagonal{Float64,<:Fill}
345+
@test ArrayLayouts.mul(Eye(11), B) isa Bidiagonal{Float64,<:Fill}
346+
@test ArrayLayouts.mul(B, Eye(11)) isa Bidiagonal{Float64,<:Fill}
347+
@test ArrayLayouts.mul(Eye(11), B) isa Bidiagonal{Float64,<:Fill}
348+
@test ArrayLayouts.mul(B, Eye(11)) isa Bidiagonal{Float64,<:Fill}
349+
350+
@test ArrayLayouts.mul(Eye{Int}(10), Eye{Int}(10)) Eye{Int}(10)
351+
@test ArrayLayouts.mul(Eye{Int}(10), Eye(10)) Eye(10)
352+
353+
F = Diagonal(Fill(2,11))
354+
@test ArrayLayouts.mul(F, 1:11) 2:2:22
355+
@test ArrayLayouts.mul(F, Diagonal(1:11)) ArrayLayouts.mul(Diagonal(1:11), F) Diagonal(2:2:22)
356+
@test ArrayLayouts.mul(F, T) isa Tridiagonal{Int,<:Fill}
357+
@test ArrayLayouts.mul(T, F) isa Tridiagonal{Int,<:Fill}
358+
@test ArrayLayouts.mul(F, T) isa Tridiagonal{Int,<:Fill}
359+
@test ArrayLayouts.mul(T, F) isa Tridiagonal{Int,<:Fill}
360+
@test ArrayLayouts.mul(F, B) isa Bidiagonal{Int,<:Fill}
361+
@test ArrayLayouts.mul(B, F) isa Bidiagonal{Int,<:Fill}
362+
@test ArrayLayouts.mul(F, B) isa Bidiagonal{Int,<:Fill}
363+
@test ArrayLayouts.mul(B, F) isa Bidiagonal{Int,<:Fill}
364+
@test ArrayLayouts.mul(F, S) isa SymTridiagonal{Int,<:Fill}
365+
@test ArrayLayouts.mul(S, F) isa SymTridiagonal{Int,<:Fill}
366+
367+
@test ArrayLayouts.mul((1:11)', F) isa AbstractMatrix{Int}
313368
end
314369

315370
@testset "Triangular col/rowsupport" begin

0 commit comments

Comments
 (0)