Skip to content

Commit 9599f64

Browse files
authored
Support ApplyArray{typeof(-)} and views of ApplyArray{typeof(±)} (#70)
* Support - just like Add * Update add.jl * Support views of Add/Subtract * comment out copyto! * restore
1 parent b5f009a commit 9599f64

File tree

4 files changed

+257
-95
lines changed

4 files changed

+257
-95
lines changed

src/lazyapplying.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ end
1616
@inline Applied{Style}(f::F, args::Args) where {Style,F,Args<:Tuple} = Applied{Style,F,Args}(f, args)
1717
@inline Applied{Style}(A::Applied) where Style = Applied{Style}(A.f, A.args)
1818

19+
arguments(a) = a.args
20+
1921
@inline check_applied_axes(A::Applied) = nothing
2022

2123
function instantiate(A::Applied{Style}) where Style

src/linalg/add.jl

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,6 @@
66

77
const Add{Factors<:Tuple} = Applied{<:Any, typeof(+), Factors}
88

9-
size(M::Add, p::Int) = size(M)[p]
10-
axes(M::Add, p::Int) = axes(M)[p]
11-
ndims(M::Add) = ndims(first(M.args))
12-
13-
length(M::Add) = prod(size(M))
14-
size(M::Add) = length.(axes(M))
15-
axes(M::Add) = axes(first(M.args))
16-
17-
18-
eltype(M::Add) = promote_type(map(eltype,M.args)...)
19-
209
const AddArray{T,N,Factors<:Tuple} = ApplyArray{T,N,typeof(+), Factors}
2110
const AddVector{T,Factors<:Tuple} = AddArray{T,1,Factors}
2211
const AddMatrix{T,Factors<:Tuple} = AddArray{T,2,Factors}
@@ -31,8 +20,27 @@ A lazy representation of `A1 + A2 + … + AN`; i.e., a shorthand for `applied(+,
3120
Add(As...) = applied(+, As...)
3221

3322

23+
24+
for op in (:+, :-)
25+
@eval begin
26+
size(M::Applied{<:Any, typeof($op)}, p::Int) = size(M)[p]
27+
axes(M::Applied{<:Any, typeof($op)}, p::Int) = axes(M)[p]
28+
ndims(M::Applied{<:Any, typeof($op)}) = ndims(first(M.args))
29+
30+
length(M::Applied{<:Any, typeof($op)}) = prod(size(M))
31+
size(M::Applied{<:Any, typeof($op)}) = length.(axes(M))
32+
axes(M::Applied{<:Any, typeof($op)}) = axes(first(M.args))
33+
34+
eltype(M::Applied{<:Any, typeof($op)}) = promote_type(map(eltype,M.args)...)
35+
36+
combine_mul_styles(::ApplyLayout{typeof($op)}) = IdentityMulStyle()
37+
end
38+
end
39+
40+
3441
getindex(M::Add, k::Integer) = sum(getindex.(M.args, k))
3542
getindex(M::Add, k::Integer, j::Integer) = sum(getindex.(M.args, k, j))
43+
3644
getindex(M::Add, k::CartesianIndex{1}) = M[convert(Int, k)]
3745
getindex(M::Add, kj::CartesianIndex{2}) = M[kj[1], kj[2]]
3846

@@ -45,19 +53,44 @@ function zero!(A::AbstractArray{<:AbstractArray})
4553
end
4654

4755
_fill_lmul!(β, A::AbstractArray{T}) where T = iszero(β) ? zero!(A) : lmul!(β, A)
48-
combine_mul_styles(::ApplyLayout{typeof(+)}) = IdentityMulStyle()
4956
for MulAdd_ in [MatMulMatAdd, MatMulVecAdd]
5057
# `MulAdd{ApplyLayout{typeof(+)}}` cannot "win" against
5158
# `MatMulMatAdd` and `MatMulVecAdd` hence `@eval`:
52-
@eval function materialize!(M::$MulAdd_{ApplyLayout{typeof(+)}})
53-
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
54-
if C B
55-
B = copy(B)
59+
@eval begin
60+
function materialize!(M::$MulAdd_{ApplyLayout{typeof(+)}})
61+
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
62+
if C B
63+
B = copy(B)
64+
end
65+
_fill_lmul!(β, C)
66+
for A in arguments(A)
67+
C .= applied(+,applied(*,α, A,B), C)
68+
end
69+
C
5670
end
57-
_fill_lmul!(β, C)
58-
for A in Applied(A).args
59-
C .= applied(+,applied(*,α, A,B), C)
71+
function materialize!(M::$MulAdd_{ApplyLayout{typeof(-)}})
72+
α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
73+
if C B
74+
B = copy(B)
75+
end
76+
_fill_lmul!(β, C)
77+
a1,a2 = arguments(A)
78+
C .= applied(+,applied(*,α, a1,B), C)
79+
C .= applied(+,applied(*,-α, a2,B), C)
80+
C
6081
end
61-
C
6282
end
6383
end
84+
85+
86+
###
87+
# views
88+
####
89+
_view(a, b::Tuple) = view(a, b...)
90+
for op in (:+, :-)
91+
@eval begin
92+
subarraylayout(a::ApplyLayout{typeof($op)}, _) = a
93+
arguments(a::SubArray{<:Any,N,<:ApplyArray{<:Any,N,typeof($op)}}) where N =
94+
_view.(arguments(parent(a)), Ref(parentindices(a)))
95+
end
96+
end

0 commit comments

Comments
 (0)