Skip to content

Commit 27f7e5c

Browse files
authored
col/rowsupport for triangular (#62)
* col/rowsupport for triangular * copyto! cached from padded * Padded Dot * default triangular rdiv * clean up triangular/diagonal. Adding tests * == for cached by zeros
1 parent 229e278 commit 27f7e5c

File tree

12 files changed

+397
-198
lines changed

12 files changed

+397
-198
lines changed

src/LazyArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, broadcas
3636
combine_eltypes, DefaultArrayStyle, instantiate, materialize,
3737
materialize!, eltypes
3838

39-
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf
39+
import LinearAlgebra: AbstractTriangular, AbstractQ, checksquare, pinv, fill!, tilebufsize, Abuf, Bbuf, Cbuf, dot
4040

4141
import LinearAlgebra.BLAS: BlasFloat, BlasReal, BlasComplex
4242

src/cache.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,11 @@ Base.replace_in_print_matrix(A::CachedMatrix, i::Integer, j::Integer, s::Abstrac
136136
# special for zero cache
137137
###
138138

139-
zero!(A::CachedArray{<:Any,N,<:Any,<:Zeros}) where N = zero!(A.data)
139+
zero!(A::CachedArray{<:Any,N,<:Any,<:Zeros}) where N = zero!(A.data)
140+
141+
###
142+
# MemoryLayout
143+
####
144+
145+
cachedlayout(_, _) = UnknownLayout()
146+
MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayout(MemoryLayout(DAT), MemoryLayout(ARR))

src/lazyconcat.jl

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,4 +471,41 @@ colsupport(M::Vcat, j) = first(colsupport(first(M.args),j)):(size(Vcat(most(M.ar
471471
####
472472

473473
struct PaddedLayout{L} <: MemoryLayout end
474-
applylayout(::Type{typeof(vcat)}, ::A, ::ZerosLayout) where A = PaddedLayout{A}()
474+
applylayout(::Type{typeof(vcat)}, ::A, ::ZerosLayout) where A = PaddedLayout{A}()
475+
cachedlayout(::A, ::ZerosLayout) where A = PaddedLayout{A}()
476+
477+
478+
paddeddata(A::CachedArray) = A.data
479+
paddeddata(A::Vcat) = A.args[1]
480+
481+
function ==(A::CachedVector{<:Any,<:Any,<:Zeros}, B::CachedVector{<:Any,<:Any,<:Zeros})
482+
length(A) == length(B) || return false
483+
n = max(length(A.data), length(B.data))
484+
resizedata!(A,n); resizedata!(B,n)
485+
A.data == B.data
486+
end
487+
488+
# special copyto! since `similar` of a padded returns a cached
489+
function copyto!(dest::CachedVector{T,Vector{T},<:Zeros{T,1}}, src::Vcat{<:Any,1,<:Tuple{<:AbstractVector,<:Zeros}}) where T
490+
length(src) length(dest) || throw(BoundsError())
491+
a,_ = src.args
492+
resizedata!(dest, length(a)) # make sure we are padded enough
493+
copyto!(dest.data, a)
494+
dest
495+
end
496+
497+
struct Dot{StyleA,StyleB,ATyp,BTyp}
498+
A::ATyp
499+
B::BTyp
500+
end
501+
502+
Dot(A::ATyp,B::BTyp) where {ATyp,BTyp} = Dot{typeof(MemoryLayout(ATyp)), typeof(MemoryLayout(BTyp)), ATyp, BTyp}(A, B)
503+
materialize(d::Dot{<:Any,<:Any,<:AbstractArray,<:AbstractArray}) = Base.invoke(dot, Tuple{AbstractArray,AbstractArray}, d.A, d.B)
504+
function materialize(d::Dot{<:PaddedLayout,<:PaddedLayout,<:AbstractVector{T},<:AbstractVector{V}}) where {T,V}
505+
a,b = paddeddata(d.A), paddeddata(d.B)
506+
m = min(length(a), length(b))
507+
convert(promote_type(T,V), dot(view(a,1:m), view(b,1:m)))
508+
end
509+
510+
dot(a::CachedArray, b::AbstractArray) = materialize(Dot(a,b))
511+
dot(a::LazyArray, b::AbstractArray) = materialize(Dot(a,b))

src/linalg/diagonal.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
####
2+
# Diagonal
3+
####
4+
5+
rowsupport(::Diagonal, k) = k:k
6+
colsupport(::Diagonal, j) = j:j
7+
8+
rowsupport(::DiagonalLayout, _, k) = k:k
9+
colsupport(::DiagonalLayout, _, j) = j:j
10+
11+
###
12+
# Lmul
13+
####
14+
15+
mulapplystyle(::DiagonalLayout, ::DiagonalLayout) = LmulStyle()
16+
17+
mulapplystyle(::DiagonalLayout, _) = LmulStyle()
18+
mulapplystyle(_, ::DiagonalLayout) = RmulStyle()
19+
20+
# Diagonal multiplication never changes structure
21+
similar(M::Lmul{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes)
22+
# equivalent to rescaling
23+
function materialize!(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}})
24+
M.B .= getindex_value(M.A.diag) .* M.B
25+
M.B
26+
end
27+
28+
copy(M::Lmul{<:DiagonalLayout,<:DiagonalLayout}) = Diagonal(M.A.diag .* M.B.diag)
29+
copy(M::Lmul{<:DiagonalLayout}) = M.A.diag .* M.B
30+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(M.A.diag) .* M.B
31+
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = Diagonal(getindex_value(M.A.diag) .* M.B.diag)
32+
33+
# Diagonal multiplication never changes structure
34+
similar(M::Rmul{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes)
35+
# equivalent to rescaling
36+
function materialize!(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
37+
M.A .= M.A .* getindex_value(M.B.diag)
38+
M.A
39+
end
40+
41+
copy(M::Rmul{<:Any,<:DiagonalLayout}) = M.A .* permutedims(M.B.diag)
42+
copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* getindex_value(M.B.diag)

src/linalg/inv.jl renamed to src/linalg/ldiv.jl

Lines changed: 7 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -116,46 +116,9 @@ const BlasMatLdivVec{styleA, styleB, T<:BlasFloat} = MatLdivVec{styleA, styleB,
116116
const BlasMatLdivMat{styleA, styleB, T<:BlasFloat} = MatLdivMat{styleA, styleB, T, T}
117117

118118

119-
120-
###
121-
# Triangular
122-
###
123-
124-
@inline function copyto!(dest::AbstractArray, M::Ldiv{<:TriangularLayout})
125-
A, B = M.A, M.B
126-
dest B || (dest .= B)
127-
materialize!(Ldiv(A, dest))
128-
end
129-
130-
@inline materialize!(M::BlasMatLdivVec{<:TriangularLayout{UPLO,UNIT,<:AbstractColumnMajor},
131-
<:AbstractStridedLayout}) where {UPLO,UNIT} =
132-
BLAS.trsv!(UPLO, 'N', UNIT, triangulardata(M.A), M.B)
133-
134-
@inline materialize!(M::BlasMatLdivVec{<:TriangularLayout{'U',UNIT,<:AbstractRowMajor},
135-
<:AbstractStridedLayout}) where {UNIT} =
136-
BLAS.trsv!('L', 'T', UNIT, transpose(triangulardata(M.A)), M.B)
137-
138-
@inline materialize!(M::BlasMatLdivVec{<:TriangularLayout{'L',UNIT,<:AbstractRowMajor},
139-
<:AbstractStridedLayout}) where {UNIT} =
140-
BLAS.trsv!('U', 'T', UNIT, transpose(triangulardata(M.A)), M.B)
141-
142-
143-
@inline materialize!(M::BlasMatLdivVec{<:TriangularLayout{'U',UNIT,<:ConjLayout{<:AbstractRowMajor}},
144-
<:AbstractStridedLayout}) where {UNIT} =
145-
BLAS.trsv!('L', 'C', UNIT, triangulardata(M.A)', M.B)
146-
147-
@inline materialize!(M::BlasMatLdivVec{<:TriangularLayout{'L',UNIT,<:ConjLayout{<:AbstractRowMajor}},
148-
<:AbstractStridedLayout}) where {UNIT,T} =
149-
BLAS.trsv!('U', 'C', UNIT, triangulardata(M.A)', M.B)
150-
151-
function materialize!(M::MatLdivMat{<:TriangularLayout})
152-
A,X = M.A,M.B
153-
size(A,2) == size(X,1) || thow(DimensionMismatch("Dimensions must match"))
154-
@views for j in axes(X,2)
155-
materialize!(Ldiv(A, X[:,j]))
156-
end
157-
X
158-
end
119+
######
120+
# PInv/Inv
121+
########
159122

160123

161124
const PInvMatrix{T,Arg} = ApplyMatrix{T,typeof(pinv),<:Tuple{Arg}}
@@ -201,4 +164,7 @@ copy(M::Applied{LdivApplyStyle}) = copy(Ldiv(M))
201164
@inline materialize!(M::Applied{LdivApplyStyle}) = materialize!(Ldiv(M))
202165

203166
@propagate_inbounds getindex(A::Applied{LazyArrayApplyStyle,typeof(\)}, kj...) =
204-
materialize(Ldiv(A))[kj...]
167+
materialize(Ldiv(A))[kj...]
168+
169+
170+

src/linalg/linalg.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ include("mul.jl")
22
include("lazymul.jl")
33
include("muladd.jl")
44
include("lmul.jl")
5-
include("inv.jl")
5+
include("ldiv.jl")
66
include("add.jl")
7-
include("factorizations.jl")
7+
include("factorizations.jl")
8+
9+
include("diagonal.jl")
10+
include("triangular.jl")

src/linalg/lmul.jl

Lines changed: 0 additions & 143 deletions
Original file line numberDiff line numberDiff line change
@@ -77,146 +77,3 @@ end
7777
materialize!(M::Lmul) = lmul!(M.A,M.B)
7878
materialize!(M::Rmul) = rmul!(M.A,M.B)
7979

80-
81-
82-
83-
84-
85-
###
86-
# Triangular
87-
###
88-
mulapplystyle(::TriangularLayout, ::AbstractStridedLayout) = LmulStyle()
89-
mulapplystyle(::AbstractStridedLayout, ::TriangularLayout) = RmulStyle()
90-
91-
92-
93-
94-
@inline function materialize!(M::BlasMatLmulVec{<:TriangularLayout{UPLO,UNIT,<:AbstractColumnMajor},
95-
<:AbstractStridedLayout}) where {UPLO,UNIT}
96-
A,x = M.A,M.B
97-
BLAS.trmv!(UPLO, 'N', UNIT, triangulardata(A), x)
98-
end
99-
100-
@inline function materialize!(M::BlasMatLmulVec{<:TriangularLayout{'U',UNIT,<:AbstractRowMajor},
101-
<:AbstractStridedLayout}) where UNIT
102-
A,x = M.A,M.B
103-
BLAS.trmv!('L', 'T', UNIT, transpose(triangulardata(A)), x)
104-
end
105-
106-
@inline function materialize!(M::BlasMatLmulVec{<:TriangularLayout{'L',UNIT,<:AbstractRowMajor},
107-
<:AbstractStridedLayout}) where UNIT
108-
A,x = M.A,M.B
109-
BLAS.trmv!('U', 'T', UNIT, transpose(triangulardata(A)), x)
110-
end
111-
112-
@inline function materialize!(M::BlasMatLmulVec{<:TriangularLayout{'U',UNIT,<:ConjLayout{<:AbstractRowMajor}},
113-
<:AbstractStridedLayout,<:BlasComplex}) where UNIT
114-
A,x = M.A,M.B
115-
BLAS.trmv!('L', 'C', UNIT, triangulardata(A)', x)
116-
end
117-
118-
@inline function materialize!(M::BlasMatLmulVec{<:TriangularLayout{'L',UNIT,<:ConjLayout{<:AbstractRowMajor}},
119-
<:AbstractStridedLayout,<:BlasComplex}) where UNIT
120-
A,x = M.A,M.B
121-
BLAS.trmv!('U', 'C', UNIT, triangulardata(A)', x)
122-
end
123-
# Triangular * Matrix
124-
125-
@inline function materialize!(M::BlasMatLmulMat{<:TriangularLayout{UPLO,UNIT,<:AbstractColumnMajor},
126-
<:AbstractStridedLayout, T}) where {UPLO,UNIT,T<:BlasFloat}
127-
A,x = M.A,M.B
128-
BLAS.trmm!('L', UPLO, 'N', UNIT, one(T), triangulardata(A), x)
129-
end
130-
131-
@inline function materialize!(M::BlasMatLmulMat{<:TriangularLayout{'L',UNIT,<:AbstractRowMajor},
132-
<:AbstractStridedLayout, T}) where {UNIT,T<:BlasFloat}
133-
A,x = M.A,M.B
134-
BLAS.trmm!('L', 'U', 'T', UNIT, one(T), transpose(triangulardata(A)), x)
135-
end
136-
137-
@inline function materialize!(M::BlasMatLmulMat{<:TriangularLayout{'U',UNIT,<:AbstractRowMajor},
138-
<:AbstractStridedLayout, T}) where {UNIT,T<:BlasFloat}
139-
A,x = M.A,M.B
140-
BLAS.trmm!('L', 'L', 'T', UNIT, one(T), transpose(triangulardata(A)), x)
141-
end
142-
143-
@inline function materialize!(M::BlasMatLmulMat{<:TriangularLayout{'L',UNIT,<:ConjLayout{<:AbstractRowMajor}},
144-
<:AbstractStridedLayout, T}) where {UNIT,T<:BlasComplex}
145-
A,x = M.A,M.B
146-
BLAS.trmm!('L', 'U', 'C', UNIT, one(T), triangulardata(A)', x)
147-
end
148-
149-
@inline function materialize!(M::BlasMatLmulMat{<:TriangularLayout{'U',UNIT,<:ConjLayout{<:AbstractRowMajor}},
150-
<:AbstractStridedLayout, T}) where {UNIT,T<:BlasComplex}
151-
A,x = M.A,M.B
152-
BLAS.trmm!('L', 'L', 'C', UNIT, one(T), triangulardata(A)', x)
153-
end
154-
155-
156-
materialize!(M::MatLmulMat{<:TriangularLayout}) = lmul!(M.A, M.B)
157-
158-
@inline function materialize!(M::BlasMatRmulMat{<:AbstractStridedLayout,
159-
<:TriangularLayout{UPLO,UNIT,<:AbstractColumnMajor},T}) where {UPLO,UNIT,T<:BlasFloat}
160-
x,A = M.A,M.B
161-
BLAS.trmm!('R', UPLO, 'N', UNIT, one(T), triangulardata(A), x)
162-
end
163-
164-
@inline function materialize!(M::BlasMatRmulMat{<:AbstractStridedLayout,
165-
<:TriangularLayout{'L',UNIT,<:AbstractRowMajor},T}) where {UNIT,T<:BlasFloat}
166-
x,A = M.A,M.B
167-
BLAS.trmm!('R', 'U', 'T', UNIT, one(T), transpose(triangulardata(A)), x)
168-
end
169-
170-
@inline function materialize!(M::BlasMatRmulMat{<:AbstractStridedLayout,
171-
<:TriangularLayout{'U',UNIT,<:AbstractRowMajor},T}) where {UNIT,T<:BlasFloat}
172-
x,A = M.A,M.B
173-
BLAS.trmm!('R', 'L', 'T', UNIT, one(T), transpose(triangulardata(A)), x)
174-
end
175-
176-
@inline function materialize!(M::BlasMatRmulMat{<:AbstractStridedLayout,
177-
<:TriangularLayout{'L',UNIT,<:ConjLayout{<:AbstractRowMajor}},T}) where {UPLO,UNIT,T<:BlasComplex}
178-
x,A = M.A,M.B
179-
BLAS.trmm!('R', 'U', 'C', UNIT, one(T), triangulardata(A)', x)
180-
end
181-
182-
@inline function materialize!(M::BlasMatRmulMat{<:AbstractStridedLayout,
183-
<:TriangularLayout{'U',UNIT,<:ConjLayout{<:AbstractRowMajor}},T}) where {UPLO,UNIT,T<:BlasComplex}
184-
x,A = M.A,M.B
185-
BLAS.trmm!('R', 'L', 'C', UNIT, one(T), triangulardata(A)', x)
186-
end
187-
188-
189-
materialize!(M::MatRmulMat{<:AbstractStridedLayout,<:TriangularLayout}) = rmul!(M.A, M.B)
190-
191-
192-
####
193-
# Diagonal
194-
####
195-
mulapplystyle(::DiagonalLayout, ::DiagonalLayout) = LmulStyle()
196-
197-
mulapplystyle(::DiagonalLayout, _) = LmulStyle()
198-
mulapplystyle(_, ::DiagonalLayout) = RmulStyle()
199-
200-
# Diagonal multiplication never changes structure
201-
similar(M::Lmul{<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.B, T, axes)
202-
# equivalent to rescaling
203-
function materialize!(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}})
204-
M.B .= getindex_value(M.A.diag) .* M.B
205-
M.B
206-
end
207-
208-
copy(M::Lmul{<:DiagonalLayout,<:DiagonalLayout}) = Diagonal(M.A.diag .* M.B.diag)
209-
copy(M::Lmul{<:DiagonalLayout}) = M.A.diag .* M.B
210-
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout}}) = getindex_value(M.A.diag) .* M.B
211-
copy(M::Lmul{<:DiagonalLayout{<:AbstractFillLayout},<:DiagonalLayout}) = Diagonal(getindex_value(M.A.diag) .* M.B.diag)
212-
213-
# Diagonal multiplication never changes structure
214-
similar(M::Rmul{<:Any,<:DiagonalLayout}, ::Type{T}, axes) where T = similar(M.A, T, axes)
215-
# equivalent to rescaling
216-
function materialize!(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}})
217-
M.A .= M.A .* getindex_value(M.B.diag)
218-
M.A
219-
end
220-
221-
copy(M::Rmul{<:Any,<:DiagonalLayout}) = M.A .* permutedims(M.B.diag)
222-
copy(M::Rmul{<:Any,<:DiagonalLayout{<:AbstractFillLayout}}) = M.A .* getindex_value(M.B.diag)

src/linalg/mul.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,18 +129,10 @@ gives an iterator containing the possible non-zero entries in the j-th column of
129129
"""
130130
colsupport(A, j) = colsupport(MemoryLayout(typeof(A)), A, j)
131131

132-
rowsupport(::Diagonal, k) = k:k
133-
colsupport(::Diagonal, j) = j:j
134-
135-
rowsupport(::DiagonalLayout, _, k) = k:k
136-
colsupport(::DiagonalLayout, _, j) = j:j
137-
138132
rowsupport(::ZerosLayout, _1, _2) = 1:0
139133
colsupport(::ZerosLayout, _1, _2) = 1:0
140134

141135

142-
143-
144136
####
145137
# MulArray
146138
#####

0 commit comments

Comments
 (0)