Skip to content

Commit ea10e2e

Browse files
authored
support combine_mul for QuasiArrays (#75)
* support combine_mul for QuasiArrays * fill layouts * Update multests.jl * Cache broadcasting * increase cov * Update cache.jl * Increase coverage * Update runtests.jl * v0.13
1 parent 67e9965 commit ea10e2e

14 files changed

+181
-23
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "0.12.4"
3+
version = "0.13"
44

55
[deps]
66
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -9,7 +9,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
99
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1010

1111
[compat]
12-
FillArrays = "0.7"
12+
FillArrays = "0.7,0.8"
1313
MacroTools = "0.4.5,0.5"
1414
StaticArrays = "0.8,0.9,0.10,0.11"
1515
julia = "1"

src/LazyArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
5757
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, , @~, LazyArray
5858

5959
include("memorylayout.jl")
60-
include("cache.jl")
6160
include("lazyapplying.jl")
6261
include("lazybroadcasting.jl")
6362
include("linalg/linalg.jl")
63+
include("cache.jl")
6464
include("lazyconcat.jl")
6565
include("lazysetoperations.jl")
6666
include("lazyoperations.jl")

src/cache.jl

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ function getindex(A::CachedArray, I...)
7373
A.data[I...]
7474
end
7575

76+
getindex(A::CachedVector, ::Colon) = copy(A)
77+
getindex(A::CachedVector, ::Slice) = copy(A)
78+
7679
function getindex(A::CachedVector, I, J...)
7780
@boundscheck checkbounds(A, I, J...)
7881
resizedata!(A, _maximum(axes(A,1), I))
@@ -166,4 +169,52 @@ zero!(A::CachedArray{<:Any,N,<:Any,<:Zeros}) where N = zero!(A.data)
166169
####
167170

168171
cachedlayout(_, _) = UnknownLayout()
169-
MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayout(MemoryLayout(DAT), MemoryLayout(ARR))
172+
MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayout(MemoryLayout(DAT), MemoryLayout(ARR))
173+
174+
175+
176+
#####
177+
# broadcasting
178+
#
179+
# We want broadcasting for numbers with concaenations to pass through
180+
# to take advantage of special implementations of the sub-components
181+
######
182+
183+
BroadcastStyle(::Type{<:CachedArray{<:Any,N}}) where N = LazyArrayStyle{N}()
184+
185+
broadcasted(::LazyArrayStyle, op, A::CachedArray) =
186+
CachedArray(broadcast(op, paddeddata(A)), broadcast(op, A.array))
187+
188+
broadcasted(::LazyArrayStyle, op, A::CachedArray, c::Number) =
189+
CachedArray(broadcast(op, paddeddata(A), c), broadcast(op, A.array, c))
190+
broadcasted(::LazyArrayStyle, op, c::Number, A::CachedArray) =
191+
CachedArray(broadcast(op, c, paddeddata(A)), broadcast(op, c, A.array))
192+
broadcasted(::LazyArrayStyle, op, A::CachedArray, c::Ref) =
193+
CachedArray(broadcast(op, paddeddata(A), c), broadcast(op, A.array, c))
194+
broadcasted(::LazyArrayStyle, op, c::Ref, A::CachedArray) =
195+
CachedArray(broadcast(op, c, paddeddata(A)), broadcast(op, c, A.array))
196+
197+
198+
function broadcasted(::LazyArrayStyle, op, A::CachedVector, B::AbstractVector)
199+
dat = paddeddata(A)
200+
n = length(dat)
201+
m = length(B)
202+
CachedArray(broadcast(op, dat, view(B,1:n)), broadcast(op, A.array, B))
203+
end
204+
205+
function broadcasted(::LazyArrayStyle, op, A::AbstractVector, B::CachedVector)
206+
dat = paddeddata(B)
207+
n = length(dat)
208+
m = length(A)
209+
CachedArray(broadcast(op, view(A,1:n), dat), broadcast(op, A, B.array))
210+
end
211+
212+
function broadcasted(::LazyArrayStyle, op, A::CachedVector, B::CachedVector)
213+
n = max(A.datasize[1],B.datasize[1])
214+
resizedata!(A,n)
215+
resizedata!(B,n)
216+
Adat = view(paddeddata(A),1:n)
217+
Bdat = view(paddeddata(B),1:n)
218+
CachedArray(broadcast(op, Adat, Bdat), broadcast(op, A.array, B.array))
219+
end
220+

src/lazyapplying.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,16 @@ for F in (:exp, :log, :sqrt, :cos, :sin, :tan, :csc, :sec, :cot,
209209
end
210210
end
211211

212-
struct LazyLayout <: MemoryLayout end
212+
abstract type AbstractLazyLayout <: MemoryLayout end
213+
struct LazyLayout <: AbstractLazyLayout end
213214

214215

215-
MemoryLayout(::Type{<:LazyArray}) = LazyLayout()
216+
MemoryLayout(::Type{<:LazyArray}) = LazyArrayLayout()
216217

217-
transposelayout(::LazyLayout) = LazyLayout()
218-
conjlayout(::LazyLayout) = LazyLayout()
218+
transposelayout(L::LazyLayout) = L
219+
conjlayout(L::LazyLayout) = L
220+
subarraylayout(L::LazyLayout, _) = L
221+
reshapedlayout(::LazyLayout, _) = LazyLayout()
219222

220223
combine_mul_styles(::LazyLayout) = LazyArrayApplyStyle()
221224
result_mul_style(::LazyArrayApplyStyle, ::LazyArrayApplyStyle) = LazyArrayApplyStyle()

src/lazybroadcasting.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ Broadcasted(A::BroadcastArray) = instantiate(broadcasted(A.f, A.args...))
4141
axes(A::BroadcastArray) = axes(Broadcasted(A))
4242
size(A::BroadcastArray) = map(length, axes(A))
4343

44-
IndexStyle(::BroadcastArray{<:Any,1}) = IndexLinear()
4544

4645
@propagate_inbounds getindex(A::BroadcastArray, kj::Int...) = Broadcasted(A)[kj...]
4746

src/linalg/ldiv.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ struct Ldiv{StyleA, StyleB, AType, BType}
3535
B::BType
3636
end
3737

38+
Ldiv{StyleA, StyleB}(A::AType, B::BType) where {StyleA,StyleB,AType,BType} =
39+
Ldiv{StyleA,StyleB,AType,BType}(A,B)
40+
3841
Ldiv(A::AType, B::BType) where {AType,BType} =
3942
Ldiv{typeof(MemoryLayout(AType)),typeof(MemoryLayout(BType)),AType,BType}(A, B)
4043

@@ -167,4 +170,11 @@ copy(M::Applied{LdivApplyStyle}) = copy(Ldiv(M))
167170
materialize(Ldiv(A))[kj...]
168171

169172

173+
###
174+
# * layout
175+
###
170176

177+
function copy(L::Ldiv{<:Any,ApplyLayout{typeof(*)}})
178+
args = arguments(L.B)
179+
apply(*, L.A \ first(args), tail(args)...)
180+
end

src/linalg/mul.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,29 +223,48 @@ subarraylayout(::ApplyLayout{typeof(*)}, _...) = ApplyLayout{typeof(*)}()
223223

224224
call(::ApplyLayout{typeof(*)}, V::SubArray) = *
225225

226-
function arguments(::ApplyLayout{typeof(*)}, V::SubArray{<:Any,2})
226+
function _mat_mul_arguments(V)
227227
P = parent(V)
228228
kr, jr = parentindices(V)
229229
as = arguments(P)
230230
kjr = intersect.(_mul_args_rows(kr, as...), _mul_args_cols(jr, reverse(as)...))
231231
view.(as, (kr, kjr...), (kjr..., jr))
232232
end
233233

234+
234235
_vec_mul_view(a...) = view(a...)
235236
_vec_mul_view(a::AbstractVector, kr, ::Colon) = view(a, kr)
236237

237-
function arguments(::ApplyLayout{typeof(*)}, V::SubArray{<:Any,1})
238+
function _vec_mul_arguments(V)
238239
P = parent(V)
239240
kr, = parentindices(V)
240241
as = arguments(P)
241242
kjr = intersect.(_mul_args_rows(kr, as...), _mul_args_cols(Base.OneTo(1), reverse(as)...))
242243
_vec_mul_view.(as, (kr, kjr...), (kjr..., :))
243244
end
244245

246+
arguments(::ApplyLayout{typeof(*)}, V::SubArray{<:Any,2}) = _mat_mul_arguments(V)
247+
arguments(::ApplyLayout{typeof(*)}, V::SubArray{<:Any,1}) = _vec_mul_arguments(V)
248+
245249
@inline sub_materialize(::ApplyLayout{typeof(*)}, V) = apply(*, arguments(V)...)
246250
@inline copyto!(dest::AbstractArray{T,N}, src::SubArray{T,N,<:ApplyArray{T,N,typeof(*)}}) where {T,N} =
247251
copyto!(dest, Applied(src))
248252

253+
##
254+
# adoint Mul
255+
##
256+
257+
adjointlayout(::Type, ::ApplyLayout{typeof(*)}) = ApplyLayout{typeof(*)}()
258+
transposelayout(::ApplyLayout{typeof(*)}) = ApplyLayout{typeof(*)}()
259+
260+
call(::ApplyLayout{typeof(*)}, V::Adjoint) = *
261+
call(::ApplyLayout{typeof(*)}, V::Transpose) = *
262+
263+
arguments(::ApplyLayout{typeof(*)}, V::Adjoint) = reverse(adjoint.(arguments(V')))
264+
arguments(::ApplyLayout{typeof(*)}, V::Transpose) = reverse(transpose.(arguments(V')))
265+
266+
267+
249268
##
250269
# * specialcase
251270
##

src/linalg/muladd.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,11 @@ scalarone(::Type{<:AbstractArray{T}}) where T = scalarone(T)
9797
scalarzero(::Type{T}) where T = zero(T)
9898
scalarzero(::Type{<:AbstractArray{T}}) where T = scalarzero(T)
9999

100+
fillzeros(::Type{T}, ax) where T = Zeros{T}(ax)
100101

101102
_αAB(M::Mul{MulAddStyle,<:Tuple{<:AbstractArray,<:AbstractArray}}, ::Type{T}) where T = tuple(scalarone(T), M.args...)
102103
_αAB(M::Mul{MulAddStyle,<:Tuple{<:Number,<:AbstractArray,<:AbstractArray}}, ::Type{T}) where T = M.args
103-
_αABβC(M::Mul, ::Type{T}) where T = tuple(_αAB(M, T)..., scalarzero(T), Zeros{T}(axes(M)))
104+
_αABβC(M::Mul, ::Type{T}) where T = tuple(_αAB(M, T)..., scalarzero(T), fillzeros(T,axes(M)))
104105

105106
_βC(M::Mul, ::Type{T}) where T = M.args
106107
_βC(M::AbstractArray, ::Type{T}) where T = (scalarone(T), M)
@@ -249,7 +250,7 @@ function default_blasmul!(α, A::AbstractMatrix, B::AbstractVector, β, C::Abstr
249250
z = zero(A[1]*B[1] + A[1]*B[1])
250251
Astride = size(A, 1) # use size, not stride, since its not pointer arithmetic
251252

252-
@inbounds for k = 1:mB
253+
@inbounds for k in colsupport(B,1)
253254
aoffs = (k-1)*Astride
254255
b = B[k]
255256
for i = 1:mA

src/memorylayout.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,14 @@ dispatch to BLAS and LAPACK routines if the memory layout is compatible.
145145
@inline MemoryLayout(::Type{<:Number}) = ScalarLayout()
146146
@inline MemoryLayout(::Type{<:DenseArray}) = DenseColumnMajor()
147147

148-
@inline MemoryLayout(::Type{<:ReinterpretArray{T,N,S,P}}) where {T,N,S,P} = reinterpretedmemorylayout(MemoryLayout(P))
149-
@inline reinterpretedmemorylayout(::MemoryLayout) = UnknownLayout()
150-
@inline reinterpretedmemorylayout(::DenseColumnMajor) = DenseColumnMajor()
148+
@inline MemoryLayout(::Type{<:ReinterpretArray{T,N,S,P}}) where {T,N,S,P} = reinterpretedlayout(MemoryLayout(P))
149+
@inline reinterpretedlayout(::MemoryLayout) = UnknownLayout()
150+
@inline reinterpretedlayout(::DenseColumnMajor) = DenseColumnMajor()
151151

152-
@inline MemoryLayout(A::Type{<:ReshapedArray{T,N,P}}) where {T,N,P} = reshapedmemorylayout(MemoryLayout(P))
153-
@inline reshapedmemorylayout(::MemoryLayout) = UnknownLayout()
154-
@inline reshapedmemorylayout(::DenseColumnMajor) = DenseColumnMajor()
152+
153+
MemoryLayout(::Type{<:ReshapedArray{T,N,A,DIMS}}) where {T,N,A,DIMS} = reshapedlayout(MemoryLayout(A), DIMS)
154+
@inline reshapedlayout(_, _) = UnknownLayout()
155+
@inline reshapedlayout(::DenseColumnMajor, _) = DenseColumnMajor()
155156

156157

157158
@inline MemoryLayout(A::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I} =
@@ -442,4 +443,10 @@ struct EyeLayout <: MemoryLayout end
442443

443444
MemoryLayout(::Type{<:AbstractFill}) = FillLayout()
444445
MemoryLayout(::Type{<:Zeros}) = ZerosLayout()
445-
diagonallayout(::ML) where ML<:AbstractFillLayout = DiagonalLayout{ML}()
446+
diagonallayout(::ML) where ML<:AbstractFillLayout = DiagonalLayout{ML}()
447+
# all sub arrays are same
448+
subarraylayout(L::AbstractFillLayout, inds::Type) = L
449+
reshapedlayout(L::AbstractFillLayout, _) = L
450+
adjointlayout(::Type, L::AbstractFillLayout) = L
451+
transposelayout(L::AbstractFillLayout) = L
452+

test/lazymultests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,5 +187,10 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data)
187187
@test MemoryLayout(typeof(Diagonal(x))) isa DiagonalLayout{LazyLayout}
188188
@test MemoryLayout(typeof(Diagonal(ApplyArray(+,x,x)))) isa DiagonalLayout{LazyLayout}
189189
@test MemoryLayout(typeof(Diagonal(1:6))) isa DiagonalLayout{UnknownLayout}
190+
191+
@test MemoryLayout(typeof(A')) isa LazyLayout
192+
@test MemoryLayout(typeof(transpose(A))) isa LazyLayout
193+
@test MemoryLayout(typeof(view(A,1:2,1:2))) isa LazyLayout
194+
@test MemoryLayout(typeof(reshape(A,4))) isa LazyLayout
190195
end
191196
end

0 commit comments

Comments
 (0)