Skip to content

Commit a38a305

Browse files
committed
Special *, /, \ for MulArray, triu/tril support
1 parent f8b1c39 commit a38a305

File tree

6 files changed

+57
-4
lines changed

6 files changed

+57
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LazyArrays"
22
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
3-
version = "0.12.3"
3+
version = "0.12.4"
44

55
[deps]
66
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/lazyapplying.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,4 +284,20 @@ end
284284
# avoid infinite-loop
285285
_base_copyto!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N} = Base.invoke(copyto!, NTuple{2,AbstractArray{T,N}}, dest, src)
286286
_base_copyto!(dest::AbstractArray, src::AbstractArray) = Base.invoke(copyto!, NTuple{2,AbstractArray}, dest, src)
287-
@inline copyto!(dest::AbstractArray, M::Applied{LazyArrayApplyStyle}) = _base_copyto!(dest, materialize(M))
287+
@inline copyto!(dest::AbstractArray, M::Applied{LazyArrayApplyStyle}) = _base_copyto!(dest, materialize(M))
288+
289+
##
290+
# triu/tril
291+
##
292+
for tri in (:tril, :triu)
293+
for op in (:axes, :size)
294+
@eval begin
295+
$op(A::Applied{<:Any,typeof($tri)}) = $op(first(A.args))
296+
$op(A::Applied{<:Any,typeof($tri)}, j) = $op(first(A.args), j)
297+
end
298+
end
299+
@eval begin
300+
ndims(::Applied{<:Any,typeof($tri)}) = 2
301+
eltype(A::Applied{<:Any,typeof($tri)}) = eltype(first(A.args))
302+
end
303+
end

src/lazyconcat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,4 +618,4 @@ end
618618
materialize!(M::MatMulMatAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}}) =
619619
materialize!(MulAdd(M.α,M.A,Array(M.B),M.β,M.C))
620620
materialize!(M::MatMulVecAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}}) =
621-
materialize!(MulAdd(M.α,M.A,Array(M.B),M.β,M.C))
621+
materialize!(MulAdd(M.α,M.A,Array(M.B),M.β,M.C))

src/linalg/mul.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,4 +244,16 @@ end
244244

245245
@inline sub_materialize(::ApplyLayout{typeof(*)}, V) = apply(*, arguments(V)...)
246246
@inline copyto!(dest::AbstractArray{T,N}, src::SubArray{T,N,<:ApplyArray{T,N,typeof(*)}}) where {T,N} =
247-
copyto!(dest, Applied(src))
247+
copyto!(dest, Applied(src))
248+
249+
##
250+
# * specialcase
251+
##
252+
253+
for op in (:*, :\)
254+
@eval broadcasted(::DefaultArrayStyle{N}, ::typeof($op), a::Number, b::ApplyArray{<:Number,N,typeof(*)}) where N =
255+
ApplyArray(*, broadcast($op,a,first(b.args)), tail(b.args)...)
256+
end
257+
258+
broadcasted(::DefaultArrayStyle{N}, ::typeof(/), b::ApplyArray{<:Number,N,typeof(*)}, a::Number) where N =
259+
ApplyArray(*, most(b.args)..., broadcast(/,last(b.args),a))

test/multests.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,4 +1053,14 @@ end
10531053
VERSION  v"1.2" && @test @allocated(copyto!(c, V))  1000
10541054
@test all(c .=== apply(*, arguments(V)...))
10551055
end
1056+
1057+
@testset "* algebra" begin
1058+
A = ApplyArray(*,[1 2; 3 4], Vcat(Fill(1,1,3),Fill(2,1,3)))
1059+
@test 2.0A isa ApplyArray
1060+
@test 2.0\A isa ApplyArray
1061+
@test A/2 isa ApplyArray
1062+
@test (2.0A) == 2.0Array(A)
1063+
@test (2.0\A) == 2.0\Array(A)
1064+
@test A/2.0 == Array(A)/2.0
1065+
end
10561066
end

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,19 @@ end
187187

188188
# bug from BandedMartrices.jl
189189
@test LazyArrays.convexunion(7:10,9:8) == LazyArrays.convexunion(9:8,7:10) == 7:10
190+
end
191+
192+
@testset "triu/tril" begin
193+
A = ApplyArray(triu,randn(2,2))
194+
@test A isa ApplyArray{Float64}
195+
@test A == triu(A.args[1])
196+
A = ApplyArray(tril,randn(2,2))
197+
@test A isa ApplyArray{Float64}
198+
@test A == tril(A.args[1])
199+
A = ApplyArray(triu,randn(2,2),1)
200+
@test A isa ApplyArray{Float64}
201+
@test A == triu(A.args[1],1)
202+
A = ApplyArray(tril,randn(2,2),-1)
203+
@test A isa ApplyArray{Float64}
204+
@test A == tril(A.args[1],-1)
190205
end

0 commit comments

Comments
 (0)