Skip to content

Commit 4d8a9b4

Browse files
authored
Orthogonalize PlusOperator and TimesOperator unwrapping (#120)
* Orthogonalize plusoperator and timesoperator unwrap * consolidate multiplication * Add tests
1 parent 48c69f8 commit 4d8a9b4

File tree

3 files changed

+42
-54
lines changed

3 files changed

+42
-54
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.5.14"
3+
version = "0.5.15"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/Operators/general/algebra.jl

Lines changed: 21 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ end
1818

1919
Base.size(P::PlusOperator,k::Integer) = size(first(P.ops),k)
2020

21+
bandwidthsmax(ops) = mapreduce(bandwidths, (t1,t2) -> max.(t1, t2), ops, init = (-720, -720) #= approximate (-∞,-∞) =#)
2122

22-
PlusOperator(opsin::Vector{Operator{T}},bi::Tuple{UT,VT}) where {T,UT,VT} =
23+
PlusOperator(opsin::Vector{Operator{T}},bi::Tuple{Any,Any} = bandwidthsmax(opsin)) where {T} =
2324
PlusOperator{T,typeof(bi)}(opsin,bi)
2425

2526
bandwidths(P::PlusOperator) = P.bandwidths
@@ -37,14 +38,6 @@ for (OP,mn) in ((:colstart,:min),(:colstop,:max),(:rowstart,:min),(:rowstop,:max
3738
end
3839
end
3940

40-
function PlusOperator(ops::Vector)
41-
# calculate bandwidths
42-
almostneginf=-720 # approximates -∞
43-
b1 = mapreduce(first bandwidths, max, ops, init = almostneginf)
44-
b2 = mapreduce(last bandwidths, max, ops, init = almostneginf)
45-
PlusOperator(ops,(b1,b2))
46-
end
47-
4841
function convert(::Type{Operator{T}},P::PlusOperator) where T
4942
if T==eltype(P)
5043
P
@@ -69,25 +62,18 @@ domain(P::PlusOperator) = commondomain(P.ops)
6962
_promote_eltypeof(As...) = _promote_eltypeof(As)
7063
_promote_eltypeof(As::Union{Vector, Tuple}) = mapreduce(eltype, promote_type, As)
7164

72-
+(A::PlusOperator,B::PlusOperator) =
73-
promoteplus(Operator{_promote_eltypeof(A,B)}[A.ops; B.ops])
74-
+(A::PlusOperator,B::PlusOperator,C::PlusOperator) =
75-
promoteplus(Operator{_promote_eltypeof(A,B,C)}[A.ops; B.ops; C.ops])
76-
+(A::PlusOperator,B::Operator) =
77-
promoteplus(Operator{_promote_eltypeof(A,B)}[A.ops; B])
78-
+(A::PlusOperator,B::ZeroOperator) = A
79-
+(A::PlusOperator,B::Operator,C::Operator) =
80-
promoteplus(Operator{_promote_eltypeof(A,B,C)}[A.ops; B; C])
81-
+(A::Operator,B::PlusOperator) =
82-
promoteplus(Operator{_promote_eltypeof(A,B)}[A; B.ops])
83-
+(A::ZeroOperator,B::PlusOperator) = B
84-
+(A::Operator,B::Operator) =
85-
promoteplus(Operator{_promote_eltypeof(A,B)}[A,B])
86-
+(A::Operator,B::Operator,C::Operator) =
87-
promoteplus(Operator{_promote_eltypeof(A,B,C)}[A,B,C])
88-
89-
65+
_extractops(A, ::Any) = [A]
66+
_extractops(A::PlusOperator, ::typeof(+)) = A.ops
9067

68+
function +(A::Operator,B::Operator)
69+
v = Operator{_promote_eltypeof(A,B)}[_extractops(A, +); _extractops(B, +)]
70+
promoteplus(v)
71+
end
72+
# Optimization for 3-term sum
73+
function +(A::Operator,B::Operator,C::Operator)
74+
v = Operator{_promote_eltypeof(A,B,C)}[_extractops(A,+); _extractops(B, +); _extractops(C, +)]
75+
promoteplus(v)
76+
end
9177

9278
Base.stride(P::PlusOperator)=mapreduce(stride,gcd,P.ops)
9379

@@ -137,7 +123,7 @@ for TYP in (:ZeroOperator,:Operator)
137123
end
138124
end
139125
+(A::ZeroOperator,B::Operator) = B+A
140-
126+
+(Z1::ZeroOperator, Z2::ZeroOperator, Z3::ZeroOperator) = (Z1 + Z2) + Z3
141127

142128

143129

@@ -257,14 +243,12 @@ TimesOperator(ops::Vector{Operator{T}}) where {T} = TimesOperator(ops,bandwidths
257243
TimesOperator(ops::Vector{OT}) where {OT<:Operator} =
258244
TimesOperator(convert(Vector{Operator{eltype(OT)}},ops),bandwidthssum(ops))
259245

260-
TimesOperator(A::TimesOperator,B::TimesOperator) =
261-
TimesOperator(Operator{_promote_eltypeof(A,B)}[A.ops; B.ops], _bandwidthssum(A, B))
262-
TimesOperator(A::TimesOperator,B::Operator) =
263-
TimesOperator(Operator{_promote_eltypeof(A,B)}[A.ops; B], _bandwidthssum(A, B))
264-
TimesOperator(A::Operator,B::TimesOperator) =
265-
TimesOperator(Operator{_promote_eltypeof(A,B)}[A; B.ops], _bandwidthssum(A, B))
266-
TimesOperator(A::Operator,B::Operator) =
267-
TimesOperator(Operator{_promote_eltypeof(A,B)}[A,B], _bandwidthssum(A, B))
246+
_extractops(A::TimesOperator, ::typeof(*)) = A.ops
247+
248+
function TimesOperator(A::Operator,B::Operator)
249+
v = Operator{_promote_eltypeof(A,B)}[_extractops(A, *); _extractops(B, *)]
250+
TimesOperator(v, _bandwidthssum(A, B))
251+
end
268252

269253

270254
==(A::TimesOperator,B::TimesOperator)=A.ops==B.ops
@@ -492,29 +476,13 @@ for OP in (:(adjoint),:(transpose))
492476
@eval $OP(A::TimesOperator)=TimesOperator(reverse!(map($OP,A.ops)))
493477
end
494478

495-
*(A::TimesOperator,B::TimesOperator) =
496-
promotetimes(Operator{_promote_eltypeof(A, B)}[A.ops; B.ops])
497-
function *(A::TimesOperator,B::Operator)
498-
if isconstop(B)
499-
promotedomainspace(convert(Number,B)*A,domainspace(B))
500-
else
501-
promotetimes(Operator{_promote_eltypeof(A, B)}[A.ops; B])
502-
end
503-
end
504-
function *(A::Operator,B::TimesOperator)
505-
if isconstop(A)
506-
promoterangespace(convert(Number,A)*B,rangespace(A))
507-
else
508-
promotetimes(Operator{_promote_eltypeof(A, B)}[A; B.ops])
509-
end
510-
end
511479
function *(A::Operator,B::Operator)
512480
if isconstop(A)
513481
promoterangespace(convert(Number,A)*B,rangespace(A))
514482
elseif isconstop(B)
515483
promotedomainspace(convert(Number,B)*A,domainspace(B))
516484
else
517-
promotetimes(Operator{_promote_eltypeof(A, B)}[A,B])
485+
promotetimes(Operator{_promote_eltypeof(A, B)}[_extractops(A, *); _extractops(B, *)])
518486
end
519487
end
520488

test/runtests.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ end
124124
@test TimesOperator(ops).ops == [M, M, M]
125125
end
126126
end
127+
M = Multiplication(f)
128+
@test coefficients(((M * M) * M) * f) == coefficients((M * M * M) * f)
129+
T = @inferred TimesOperator(M, M)
130+
TM = @inferred TimesOperator(T, M)
131+
MT = @inferred TimesOperator(M, T)
132+
TT = @inferred TimesOperator(T, T)
133+
@test T == M * M
134+
@test TM == T * M
135+
@test MT == M * T
136+
@test T * M == M * T == M * M * M
137+
@test TT == T * T == M * M * M * M
127138
end
128139
@testset "plus operator" begin
129140
c = [1,2,3]
@@ -140,7 +151,16 @@ end
140151
op3 = op + op
141152
@test bandwidths(op3) == bandwidths(M)
142153
@test coefficients(op3 * f) == @. 2(1+t)*c^2
154+
155+
f1 = (op + op - op)*f
156+
f2 = ((op + op) - op)*f
157+
f3 = op * f
158+
@test coefficients(f1) == coefficients(f2) == coefficients(f3)
143159
end
160+
Z = ApproxFunBase.ZeroOperator()
161+
@test Z + Z == Z
162+
@test Z + Z + Z == Z
163+
@test Z + Z + Z + Z == Z
144164
end
145165
end
146166

0 commit comments

Comments
 (0)