Skip to content

Commit 7777232

Browse files
authored
Banded * Strided is always simplifiable (#343)
* Banded * Strided is always simplifiable * fix tests * Update mul.jl * Update mul.jl * Update multests.jl * v2.2 * Update mul.jl * Update multests.jl
1 parent 223d86f commit 7777232

File tree

5 files changed

+37
-15
lines changed

5 files changed

+37
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "2.1.9"
3+
version = "2.2"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

ext/LazyArraysBandedMatricesExt.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -567,8 +567,8 @@ copy(M::Mul{<:DiagonalLayout, <:BandedLazyLayouts}) = simplify(M)
567567
copy(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}, <:BandedLazyLayouts}) = copy(mulreduce(M))
568568
copy(M::Mul{<:BandedLazyLayouts, <:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = copy(mulreduce(M))
569569

570-
simplifiable(::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true)
571-
simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = Val(true)
570+
simplifiable(::Mul{<:BandedLayouts, <:DiagonalLayout{<:OnesLayout}}) = Val(true)
571+
simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLayouts}) = Val(true)
572572
copy(M::Mul{<:BandedLazyLayouts, <:DiagonalLayout{<:OnesLayout}}) = _copy_oftype(M.A, eltype(M))
573573
copy(M::Mul{<:DiagonalLayout{<:OnesLayout}, <:BandedLazyLayouts}) = _copy_oftype(M.B, eltype(M))
574574

@@ -644,8 +644,8 @@ copy(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}},
644644
simplifiable(M::Mul{<:AbstractInvLayout{<:BandedLayouts}, <:Union{PaddedColumns,PaddedLayout,AbstractStridedLayout}}) = Val(true)
645645
simplifiable(M::Mul{<:Union{PaddedRows,PaddedLayout,DualLayout{<:PaddedRows}}, <:BandedLayouts}) = Val(true)
646646
simplifiable(M::Mul{<:BandedLayouts, <:Union{PaddedColumns,PaddedLayout}}) = Val(true)
647-
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLazyLayouts}) = Val(true)
648-
simplifiable(M::Mul{<:BandedLazyLayouts, <:AbstractStridedLayout}) = Val(true)
647+
simplifiable(M::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:BandedLayouts}) = Val(true)
648+
simplifiable(M::Mul{<:BandedLayouts, <:AbstractStridedLayout}) = Val(true)
649649

650650
copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where Lay = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))
651651
copy(L::Ldiv{ApplyBandedLayout{typeof(*)}, Lay}) where {Lay<:AbstractLazyLayout} = copy(Ldiv{ApplyLayout{typeof(*)},Lay}(L.A, L.B))

src/LazyArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import ArrayLayouts: AbstractQLayout, Dot, Dotu, Ldiv, Lmul, MatMulMatAdd, MatMu
3838
hermitianlayout, layout_getindex, layout_replace_in_print_matrix, ldivaxes, materialize,
3939
materialize!, mulreduce, reshapedlayout, rowsupport, scalarone, scalarzero, sub_materialize,
4040
sublayout, symmetriclayout, symtridiagonallayout, transposelayout, triangulardata,
41-
triangularlayout, tridiagonallayout, zero!, transtype
41+
triangularlayout, tridiagonallayout, zero!, transtype, OnesLayout
4242

4343
import FillArrays: AbstractFill, getindex_value
4444

src/linalg/mul.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,18 +371,23 @@ applylayout(::Type{typeof(*)}, ::DualLayout{Lay}, args...) where Lay = DualLayou
371371
transtype(A::MulMatrix) = transtype(first(A.args))
372372

373373
#TODO: Why not all DiagonalLayout?
374-
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
375-
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
376-
@inline simplifiable(M::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
374+
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
375+
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
376+
@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
377+
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout{<:OnesLayout}}) = Val(true)
378+
@inline simplifiable(::Mul{<:DiagonalLayout{<:OnesLayout}, <:DiagonalLayout{<:OnesLayout}}) = Val(true) # ambiguity
379+
@inline simplifiable(::Mul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
380+
@inline simplifiable(::Mul{<:Union{AbstractStridedLayout,DualLayout{<:AbstractStridedLayout}}, <:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
381+
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout}, <:AbstractStridedLayout}) = Val(true)
377382
@inline copy(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:LazyLayouts}) = copy(mulreduce(M))
378383
@inline copy(M::Mul{<:LazyLayouts,<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
379384
@inline copy(M::Mul{BroadcastLayout{typeof(*)},<:DiagonalLayout{<:AbstractFillLayout}}) = copy(mulreduce(M))
380385

381-
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
382-
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
383-
@inline simplifiable(M::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
384-
@inline simplifiable(M::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
385-
@inline simplifiable(M::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
386+
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
387+
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
388+
@inline simplifiable(::Mul{<:Any,<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
389+
@inline simplifiable(::Mul{<:Union{ZerosLayout,DualLayout{ZerosLayout}},<:DiagonalLayout{<:AbstractFillLayout}}) = Val(true)
390+
@inline simplifiable(::Mul{<:DiagonalLayout{<:AbstractFillLayout},<:Union{ZerosLayout,DualLayout{ZerosLayout}}}) = Val(true)
386391

387392

388393
# inv

test/multests.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, Base64, Random, BandedMatrices
1+
using Test, LinearAlgebra, LazyArrays, FillArrays, Base64, Random
2+
using BandedMatrices
3+
using StaticArrays
24
import LazyArrays: MulAdd, MemoryLayout, DenseColumnMajor, DiagonalLayout, SymTridiagonalLayout, Add, AddArray,
35
MulStyle, MulAddStyle, Applied, ApplyStyle, Lmul, ApplyArrayBroadcastStyle, DefaultArrayApplyStyle,
46
Rmul, ApplyLayout, arguments, colsupport, rowsupport, lazymaterialize
@@ -1176,8 +1178,23 @@ end
11761178

11771179
@testset "simplifiable tests" begin
11781180
A = randn(5,5)
1181+
b = randn(5)
1182+
D = Diagonal(Fill(2,5))
1183+
E = Eye(5)
11791184
@test LazyArrays.simplifiable(*, A) == Val(false)
11801185
@test LazyArrays.simplify(Applied(*, A, A)) == A*A
1186+
@test LazyArrays.simplifiable(*, A, D) == Val(true)
1187+
@test LazyArrays.simplifiable(*, D, A) == Val(true)
1188+
@test LazyArrays.simplifiable(*, b', D) == Val(true)
1189+
@test LazyArrays.simplifiable(*, D, b) == Val(true)
1190+
@test LazyArrays.simplifiable(*, A, E) == Val(true)
1191+
@test LazyArrays.simplifiable(*, E, A) == Val(true)
1192+
@test LazyArrays.simplifiable(*, b', E) == Val(true)
1193+
@test LazyArrays.simplifiable(*, E, b) == Val(true)
1194+
@test LazyArrays.simplifiable(*, E, D) == Val(true)
1195+
@test LazyArrays.simplifiable(*, D, E) == Val(true)
1196+
@test LazyArrays.simplifiable(*, D, D) == Val(true)
1197+
@test LazyArrays.simplifiable(*, E, E) == Val(true)
11811198
end
11821199
end
11831200

0 commit comments

Comments
 (0)