Skip to content

Commit 229e278

Browse files
authored
Improve Lmul (#61)
* Move Lmul to separate file * axis, similar, copy for Lmul * use Lmul for Diagonal * Add Rmul * Update multests.jl * Update lmul.jl * Add Rmul for Triangular * Fix UPLO in Triangular mul * Update multests.jl * v0.12, BroadcsatMatrix macro * improvements for broadcast array * Simplify ApplyStyle * Update memorylayouttests.jl * Update lazybroadcasting.jl * Update memorylayouttests.jl * improve coverage * Improve coverage of Rmul * fix broadcast row/colsupport * don't instantiate F in apply/broadcastlayout
1 parent b58b239 commit 229e278

18 files changed

+671
-249
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "0.11.1"
3+
version = "0.12"
44

55
[deps]
66
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/LazyArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ else
5252
end
5353

5454
export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
55-
Hcat, Vcat, Kron, BroadcastArray, cache, Ldiv, Inv, PInv, Diff, Cumsum,
55+
Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum,
5656
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, , @~, LazyArray
5757

5858
include("memorylayout.jl")

src/lazyapplying.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ array.
141141
"""
142142
abstract type LazyArray{T,N} <: AbstractArray{T,N} end
143143

144+
const LazyMatrix{T} = LazyArray{T,2}
145+
const LazyVector{T} = LazyArray{T,1}
146+
144147
struct ApplyArray{T, N, F, Args<:Tuple} <: LazyArray{T,N}
145148
f::F
146149
args::Args
@@ -214,10 +217,14 @@ result_mul_style(::LazyArrayApplyStyle, _) = LazyArrayApplyStyle()
214217
result_mul_style(_, ::LazyArrayApplyStyle) = LazyArrayApplyStyle()
215218

216219

217-
struct ApplyLayout{F, LAY} <: MemoryLayout end
220+
struct ApplyLayout{F} <: MemoryLayout end
221+
222+
applylayout(::Type{F}, args...) where F = ApplyLayout{F}()
218223

219-
MemoryLayout(M::Type{Applied{Style,F,Args}}) where {Style,F,Args} = ApplyLayout{F,tuple_type_memorylayouts(Args)}()
220-
MemoryLayout(M::Type{ApplyArray{T,N,F,Args}}) where {T,N,F,Args} = ApplyLayout{F,tuple_type_memorylayouts(Args)}()
224+
MemoryLayout(::Type{Applied{Style,F,Args}}) where {Style,F,Args} =
225+
applylayout(F, tuple_type_memorylayouts(Args)...)
226+
MemoryLayout(::Type{ApplyArray{T,N,F,Args}}) where {T,N,F,Args} =
227+
applylayout(F, tuple_type_memorylayouts(Args)...)
221228

222229
function show(io::IO, A::Applied)
223230
print(io, "Applied(", A.f)
@@ -230,7 +237,7 @@ end
230237
applybroadcaststyle(_1, _2) = DefaultArrayStyle{2}()
231238
BroadcastStyle(M::Type{<:ApplyArray}) = applybroadcaststyle(M, MemoryLayout(M))
232239

233-
Base.replace_in_print_matrix(A::ApplyMatrix, i::Integer, j::Integer, s::AbstractString) =
240+
Base.replace_in_print_matrix(A::LazyMatrix, i::Integer, j::Integer, s::AbstractString) =
234241
i in colsupport(A,j) ? s : Base.replace_with_centered_mark(s)
235242

236243
###
@@ -263,5 +270,6 @@ end
263270
@inline getindex(A::ApplyMatrix, kr::AbstractUnitRange, jr::AbstractUnitRange) = lazy_getindex(A, kr, jr)
264271

265272

266-
diagonallayout(::LazyLayout) = LazyLayout()
267-
diagonallayout(::ApplyLayout) = DiagonalLayout{LazyLayout}()
273+
diagonallayout(::LazyLayout) = DiagonalLayout{LazyLayout}()
274+
diagonallayout(::ApplyLayout) = DiagonalLayout{LazyLayout}()
275+

src/lazybroadcasting.jl

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,25 @@ BroadcastArray{T,N}(bc::Broadcasted{Style,Axes,F,Args}) where {T,N,Style,Axes,F,
1818
BroadcastArray{T}(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Tuple{Vararg{Any,N}},<:Any,<:Tuple}) where {T,N} =
1919
BroadcastArray{T,N}(bc)
2020

21+
BroadcastVector(bc::Broadcasted) = BroadcastVector{combine_eltypes(bc.f, bc.args)}(bc)
22+
BroadcastMatrix(bc::Broadcasted) = BroadcastMatrix{combine_eltypes(bc.f, bc.args)}(bc)
23+
2124
_broadcast2broadcastarray(a, b...) = tuple(a, b...)
2225
_broadcast2broadcastarray(a::Broadcasted, b...) = tuple(BroadcastArray(a), b...)
2326

2427
_BroadcastArray(bc::Broadcasted) = BroadcastArray{combine_eltypes(bc.f, bc.args)}(bc)
2528
BroadcastArray(bc::Broadcasted{S}) where S =
2629
_BroadcastArray(instantiate(Broadcasted{S}(bc.f, _broadcast2broadcastarray(bc.args...))))
27-
BroadcastArray(b::BroadcastArray) = b
30+
2831
BroadcastArray(f, A, As...) = BroadcastArray(broadcasted(f, A, As...))
32+
BroadcastMatrix(f, A...) = BroadcastMatrix(broadcasted(f, A...))
33+
BroadcastVector(f, A...) = BroadcastVector(broadcasted(f, A...))
2934

30-
Broadcasted(A::BroadcastArray) = instantiate(broadcasted(A.f, A.args...))
35+
BroadcastArray(b::BroadcastArray) = b
36+
BroadcastVector(A::BroadcastVector) = A
37+
BroadcastMatrix(A::BroadcastMatrix) = A
3138

39+
Broadcasted(A::BroadcastArray) = instantiate(broadcasted(A.f, A.args...))
3240

3341
axes(A::BroadcastArray) = axes(Broadcasted(A))
3442
size(A::BroadcastArray) = map(length, axes(A))
@@ -73,26 +81,33 @@ end
7381

7482

7583
BroadcastStyle(::Type{<:BroadcastArray{<:Any,N}}) where N = LazyArrayStyle{N}()
84+
BroadcastStyle(::Type{<:Adjoint{<:Any,<:BroadcastVector{<:Any}}}) where N = LazyArrayStyle{2}()
85+
BroadcastStyle(::Type{<:Transpose{<:Any,<:BroadcastVector{<:Any}}}) where N = LazyArrayStyle{2}()
86+
BroadcastStyle(::Type{<:Adjoint{<:Any,<:BroadcastMatrix{<:Any}}}) where N = LazyArrayStyle{2}()
87+
BroadcastStyle(::Type{<:Transpose{<:Any,<:BroadcastMatrix{<:Any}}}) where N = LazyArrayStyle{2}()
7688
BroadcastStyle(L::LazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L
7789
BroadcastStyle(::StaticArrayStyle{N}, L::LazyArrayStyle{N}) where N = L
7890

7991
"""
80-
BroadcastLayout(f, layouts)
92+
BroadcastLayout{F}()
8193
8294
is returned by `MemoryLayout(A)` if a matrix `A` is a `BroadcastArray`.
83-
`f` is a function that broadcast operation is applied and `layouts` is
84-
a tuple of `MemoryLayout` of the broadcasted arguments.
95+
`F` is the typeof function that broadcast operation is applied.
8596
"""
86-
struct BroadcastLayout{F, LAY} <: MemoryLayout end
87-
88-
tuple_type_memorylayouts(::Type{I}) where I<:Tuple = Tuple{typeof.(MemoryLayout.(I.parameters))...}
89-
tuple_type_memorylayouts(::Type{Tuple{A}}) where {A} = Tuple{typeof(MemoryLayout(A))}
90-
tuple_type_memorylayouts(::Type{Tuple{A,B}}) where {A,B} = Tuple{typeof(MemoryLayout(A)),typeof(MemoryLayout(B))}
91-
tuple_type_memorylayouts(::Type{Tuple{A,B,C}}) where {A,B,C} = Tuple{typeof(MemoryLayout(A)),typeof(MemoryLayout(B)),typeof(MemoryLayout(C))}
92-
97+
struct BroadcastLayout{F} <: MemoryLayout end
98+
99+
tuple_type_memorylayouts(::Type{I}) where I<:Tuple = MemoryLayout.(I.parameters)
100+
tuple_type_memorylayouts(::Type{Tuple{A}}) where {A} = (MemoryLayout(A),)
101+
tuple_type_memorylayouts(::Type{Tuple{A,B}}) where {A,B} = (MemoryLayout(A),MemoryLayout(B))
102+
tuple_type_memorylayouts(::Type{Tuple{A,B,C}}) where {A,B,C} = (MemoryLayout(A),MemoryLayout(B),MemoryLayout(C))
103+
104+
broadcastlayout(::Type{F}, _...) where F = BroadcastLayout{F}()
105+
broadcastlayout(::Type, ::LazyLayout...) = LazyLayout()
106+
broadcastlayout(::Type, _, ::LazyLayout) = LazyLayout()
107+
broadcastlayout(::Type, _, _, ::LazyLayout) = LazyLayout()
108+
broadcastlayout(::Type, _, _, _, ::LazyLayout) = LazyLayout()
93109
MemoryLayout(::Type{BroadcastArray{T,N,F,Args}}) where {T,N,F,Args} =
94-
BroadcastLayout{F,tuple_type_memorylayouts(Args)}()
95-
110+
broadcastlayout(F, tuple_type_memorylayouts(Args)...)
96111
## scalar-range broadcast operations ##
97112
# Ranges already support smart broadcasting
98113
for op in (+, -, big)
@@ -134,3 +149,18 @@ broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::AbstractArray{V,
134149
broadcast(DefaultArrayStyle{N}(), *, a, b)
135150

136151
diagonallayout(::BroadcastLayout) = DiagonalLayout{LazyLayout}()
152+
153+
154+
###
155+
# support
156+
###
157+
158+
_broadcast_colsupport(sz, A::Number, j) = OneTo(sz[1])
159+
_broadcast_colsupport(sz, A::AbstractVector, j) = colsupport(A,j)
160+
_broadcast_colsupport(sz, A::AbstractMatrix, j) = size(A,1) == 1 ? OneTo(sz[1]) : colsupport(A,j)
161+
_broadcast_rowsupport(sz, A::Number, j) = OneTo(sz[2])
162+
_broadcast_rowsupport(sz, A::AbstractVector, j) = OneTo(sz[2])
163+
_broadcast_rowsupport(sz, A::AbstractMatrix, j) = size(A,2) == 1 ? OneTo(sz[2]) : rowsupport(A,j)
164+
165+
colsupport(::BroadcastLayout{typeof(*)}, A, j) = intersect(_broadcast_colsupport.(Ref(size(A)), A.args, j)...)
166+
rowsupport(::BroadcastLayout{typeof(*)}, A, j) = intersect(_broadcast_rowsupport.(Ref(size(A)), A.args, j)...)

src/lazyconcat.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ end
436436
# *
437437
###
438438

439-
function materialize!(M::MatMulVecAdd{<:ApplyLayout{typeof(hcat)},<:ApplyLayout{typeof(vcat)}})
439+
function materialize!(M::MatMulVecAdd{ApplyLayout{typeof(hcat)},ApplyLayout{typeof(vcat)}})
440440
α,A,B,β,C = M.α,M.A,M.B,M.β,M.C
441441
T = eltype(C)
442442
_fill_lmul!(β,C) # this is temporary until strong β = false is supported
@@ -446,7 +446,7 @@ function materialize!(M::MatMulVecAdd{<:ApplyLayout{typeof(hcat)},<:ApplyLayout{
446446
C
447447
end
448448

449-
function materialize!(M::MatMulMatAdd{<:ApplyLayout{typeof(hcat)},<:ApplyLayout{typeof(vcat)}})
449+
function materialize!(M::MatMulMatAdd{ApplyLayout{typeof(hcat)},ApplyLayout{typeof(vcat)}})
450450
α,A,B,β,C = M.α,M.A,M.B,M.β,M.C
451451
T = eltype(C)
452452
_fill_lmul!(β,C) # this is temporary until strong β = false is supported
@@ -462,4 +462,13 @@ function materialize!(M::MatMulVecAdd{<:ApplyLayout{typeof(hcat)},<:ApplyLayout{
462462

463463

464464
most(a) = reverse(tail(reverse(a)))
465-
colsupport(M::Vcat, j) = first(colsupport(first(M.args),j)):(size(Vcat(most(M.args)...),1)+last(colsupport(last(M.args),j)))
465+
colsupport(M::Vcat, j) = first(colsupport(first(M.args),j)):(size(Vcat(most(M.args)...),1)+last(colsupport(last(M.args),j)))
466+
467+
468+
469+
###
470+
# padded
471+
####
472+
473+
struct PaddedLayout{L} <: MemoryLayout end
474+
applylayout(::Type{typeof(vcat)}, ::A, ::ZerosLayout) where A = PaddedLayout{A}()

src/linalg/add.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ end
4747
_fill_lmul!(β, A::AbstractArray{T}) where T = iszero(β) ? zero!(A) : lmul!(β, A)
4848
combine_mul_styles(::ApplyLayout{typeof(+)}) = IdentityMulStyle()
4949
for MulAdd_ in [MatMulMatAdd, MatMulVecAdd]
50-
# `MulAdd{<:ApplyLayout{typeof(+)}}` cannot "win" against
50+
# `MulAdd{ApplyLayout{typeof(+)}}` cannot "win" against
5151
# `MatMulMatAdd` and `MatMulVecAdd` hence `@eval`:
52-
@eval function materialize!(M::$MulAdd_{<:ApplyLayout{typeof(+)}})
52+
@eval function materialize!(M::$MulAdd_{ApplyLayout{typeof(+)}})
5353
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
5454
if C B
5555
B = copy(B)

src/linalg/factorizations.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ function check_mul_axes(A::AbstractQ, B, C...)
1414
check_mul_axes(B, C...)
1515
end
1616

17+
copy(M::Lmul{QLayout}) = copyto!(similar(M), M)
1718

1819
function copyto!(dest::AbstractArray{T}, M::Lmul{QLayout}) where T
1920
A,B = M.A,M.B
2021
if size(dest,1) == size(B,1)
2122
copyto!(dest, B)
2223
else
2324
copyto!(view(dest,1:size(B,1),:), B)
24-
fill!(@view(dest[size(B,1)+1:end,:]), zero(T))
25+
zero!(@view(dest[size(B,1)+1:end,:]))
2526
end
2627
materialize!(Lmul(A,dest))
2728
end

src/linalg/inv.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,14 @@ size(A::PInvMatrix) = map(length, axes(A))
178178
@propagate_inbounds getindex(A::InvMatrix{T}, k::Int, j::Int) where T =
179179
(parent(A)\[Zeros(j-1); one(T); Zeros(size(A,2) - j)])[k]
180180

181-
mulapplystyle(::ApplyLayout{typeof(inv),Tuple{A}}, B) where A = ldivapplystyle(A, B)
182-
mulapplystyle(::ApplyLayout{typeof(pinv),Tuple{A}}, B) where A = ldivapplystyle(A, B)
181+
struct InvLayout{L} <: MemoryLayout end
182+
struct PInvLayout{L} <: MemoryLayout end
183+
184+
applylayout(::Type{typeof(inv)}, ::A) where A = InvLayout{A}()
185+
applylayout(::Type{typeof(pinv)}, ::A) where A = PInvLayout{A}()
186+
187+
mulapplystyle(::InvLayout{A}, B) where A = ldivapplystyle(A, B)
188+
mulapplystyle(::PInvLayout{A}, B) where A = ldivapplystyle(A, B)
183189

184190

185191
@inline function Ldiv(M::Mul)

src/linalg/lazymul.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ macro lazymul(Typ)
2828
Base.:*(A::LazyArrays.ApplyMatrix, B::$Typ, C...) = LazyArrays.apply(*,A,B,C...)
2929
end
3030
end
31+
if Typ  :BroadcastMatrix && Typ :ApplyMatrix
32+
ret = quote
33+
$ret
34+
Base.:*(A::$Typ, B::LazyArrays.BroadcastMatrix, C...) = LazyArrays.apply(*,A,B,C...)
35+
Base.:*(A::LazyArrays.BroadcastMatrix, B::$Typ, C...) = LazyArrays.apply(*,A,B,C...)
36+
end
37+
end
3138
for Struc in (:AbstractTriangular, :Diagonal)
3239
ret = quote
3340
$ret
@@ -60,6 +67,18 @@ macro lazymul(Typ)
6067
Base.:*(A::LinearAlgebra.AbstractTriangular, B::$Mod{<:Any,<:$Typ}, C...) = LazyArrays.apply(*,A,B, C...)
6168
Base.:*(A::$Mod{<:Any,<:$Typ}, B::LinearAlgebra.AbstractTriangular, C...) = LazyArrays.apply(*,A,B, C...)
6269
end
70+
if Typ :ApplyMatrix
71+
ret = quote
72+
$ret
73+
Base.:*(A::$Mod{<:Any,<:$Typ}, B::ApplyMatrix, C...) = LazyArrays.apply(*,A,B, C...)
74+
end
75+
end
76+
if Typ  :BroadcastMatrix && Typ :ApplyMatrix
77+
ret = quote
78+
$ret
79+
Base.:*(A::$Mod{<:Any,<:$Typ}, B::BroadcastMatrix, C...) = LazyArrays.apply(*,A,B, C...)
80+
end
81+
end
6382
end
6483

6584
esc(ret)

src/linalg/linalg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
include("mul.jl")
22
include("lazymul.jl")
33
include("muladd.jl")
4+
include("lmul.jl")
45
include("inv.jl")
56
include("add.jl")
67
include("factorizations.jl")

0 commit comments

Comments
 (0)