Skip to content

Commit 5dfa945

Browse files
authored
Parameterize operators in PlusOperator/TimesOperator (#211)
* Relax operators for Plus * add parameter to timesoperator * implicit promotion im 2-arg timesop * aggressive constprop in promotetimes * specialize promotetimes for 2 args * don't promote vector in promoteplus * constant prop in operator exponentiation * move extra type param to end * version bump to v0.7.9
1 parent 1b0b043 commit 5dfa945

File tree

2 files changed

+82
-45
lines changed

2 files changed

+82
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.7.8"
3+
version = "0.7.9"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/Operators/general/algebra.jl

Lines changed: 81 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ export PlusOperator, TimesOperator, mul_coefficients
44

55

66

7-
struct PlusOperator{T,BI,SZ} <: Operator{T}
8-
ops::Vector{Operator{T}}
7+
struct PlusOperator{T,BI,SZ,O<:Operator{T}} <: Operator{T}
8+
ops::Vector{O}
99
bandwidths::BI
1010
sz :: SZ
11-
function PlusOperator{T,BI,SZ}(opsin::Vector{Operator{T}},bi::BI,sz::SZ) where {T,BI,SZ}
11+
function PlusOperator{T,BI,SZ,O}(opsin::Vector{O},bi::BI,sz::SZ) where {T,O<:Operator{T},BI,SZ}
1212
all(x -> size(x)==sz, opsin) || throw("sizes of operators are incompatible")
13-
new{T,BI,SZ}(opsin,bi,sz)
13+
new{T,BI,SZ,O}(opsin,bi,sz)
1414
end
1515
end
1616

@@ -19,11 +19,11 @@ size(P::PlusOperator, k::Integer) = P.sz[k]
1919

2020
bandwidthsmax(ops) = mapreduce(bandwidths, (t1,t2) -> max.(t1, t2), ops, init = (-720, -720) #= approximate (-∞,-∞) =#)
2121

22-
function PlusOperator(opsin::Vector{Operator{T}},
22+
function PlusOperator(opsin::Vector{O},
2323
bi::Tuple{Any,Any} = bandwidthsmax(opsin),
2424
sz::Tuple{Any,Any} = size(first(opsin)),
25-
) where {T}
26-
PlusOperator{T,typeof(bi),typeof(sz)}(opsin,bi,sz)
25+
) where {O<:Operator}
26+
PlusOperator{eltype(O),typeof(bi),typeof(sz),O}(opsin,bi,sz)
2727
end
2828

2929
bandwidths(P::PlusOperator) = P.bandwidths
@@ -41,12 +41,13 @@ for (OP,mn) in ((:colstart,:min),(:colstop,:max),(:rowstart,:min),(:rowstop,:max
4141
end
4242
end
4343

44-
function convert(::Type{Operator{T}},P::PlusOperator) where T
44+
function convert(::Type{Operator{T}}, P::PlusOperator) where T
4545
if T==eltype(P)
4646
P
4747
else
48-
PlusOperator{T,typeof(P.bandwidths),typeof(P.sz)}(
49-
Vector{Operator{T}}(P.ops),P.bandwidths,P.sz)
48+
ops = P.ops
49+
PlusOperator(ops isa AbstractVector{<:Operator{T}} ? ops : map(x -> strictconvert(Operator{T}, x), ops),
50+
P.bandwidths,P.sz)::Operator{T}
5051
end
5152
end
5253

@@ -72,12 +73,12 @@ _extractops(A, ::Any) = [A]
7273
_extractops(A::PlusOperator, ::typeof(+)) = A.ops
7374

7475
function +(A::Operator,B::Operator)
75-
v = Operator{_promote_eltypeof(A,B)}[_extractops(A, +); _extractops(B, +)]
76+
v = [_extractops(A, +); _extractops(B, +)]
7677
promoteplus(v, size(A))
7778
end
7879
# Optimization for 3-term sum
7980
function +(A::Operator,B::Operator,C::Operator)
80-
v = Operator{_promote_eltypeof(A,B,C)}[_extractops(A,+); _extractops(B, +); _extractops(C, +)]
81+
v = [_extractops(A,+); _extractops(B, +); _extractops(C, +)]
8182
promoteplus(v, size(A))
8283
end
8384

@@ -212,12 +213,12 @@ BLAS.axpy!(α,S::SubOperator{T,OP},A::AbstractMatrix) where {T,OP<:ConstantTimes
212213

213214

214215

215-
struct TimesOperator{T,BI,SZ} <: Operator{T}
216-
ops::Vector{Operator{T}}
216+
struct TimesOperator{T,BI,SZ,O<:Operator{T}} <: Operator{T}
217+
ops::Vector{O}
217218
bandwidths::BI
218219
sz::SZ
219220

220-
function TimesOperator{T,BI,SZ}(ops::Vector{Operator{T}},bi::BI,sz::SZ) where {T,BI,SZ}
221+
function TimesOperator{T,BI,SZ,O}(ops::Vector{O},bi::BI,sz::SZ) where {T,O<:Operator{T},BI,SZ}
221222
# check compatible
222223
for k=1:length(ops)-1
223224
size(ops[k],2) == size(ops[k+1],1) || throw(ArgumentError("incompatible operator sizes"))
@@ -235,7 +236,7 @@ struct TimesOperator{T,BI,SZ} <: Operator{T}
235236
newops = ops
236237
end
237238

238-
new{T,BI,SZ}(newops,bi,sz)
239+
new{T,BI,SZ,O}(newops,bi,sz)
239240
end
240241
end
241242

@@ -248,22 +249,16 @@ __bandwidthssum(A, B::NTuple{2,InfiniteCardinal{0}}) = B
248249
__bandwidthssum(A, B) = reduce((t1, t2) -> t1 .+ t2, (A, B), init = (0,0))
249250

250251
_timessize(ops) = (size(first(ops),1), size(last(ops),2))
251-
function TimesOperator(ops::Vector{Operator{T}},
252+
function TimesOperator(ops::Vector{O},
252253
bi::Tuple{Any,Any} = bandwidthssum(ops),
253-
sz::Tuple{Any,Any} = _timessize(ops)) where {T}
254-
TimesOperator{T,typeof(bi),typeof(sz)}(ops,bi,sz)
255-
end
256-
257-
function TimesOperator(ops::Vector{OT}) where {OT<:Operator}
258-
TimesOperator(strictconvert(
259-
Vector{Operator{eltype(OT)}},ops),
260-
bandwidthssum(ops), _timessize(ops))
254+
sz::Tuple{Any,Any} = _timessize(ops)) where {T,O<:Operator{T}}
255+
TimesOperator{T,typeof(bi),typeof(sz),O}(ops,bi,sz)
261256
end
262257

263258
_extractops(A::TimesOperator, ::typeof(*)) = A.ops
264259

265260
function TimesOperator(A::Operator,B::Operator)
266-
v = Operator{_promote_eltypeof(A,B)}[_extractops(A, *); _extractops(B, *)]
261+
v = [_extractops(A, *); _extractops(B, *)]
267262
TimesOperator(v, _bandwidthssum(A, B), _timessize((A,B)))
268263
end
269264

@@ -274,31 +269,65 @@ function convert(::Type{Operator{T}},P::TimesOperator) where T
274269
if T==eltype(P)
275270
P
276271
else
277-
TimesOperator(strictconvert(Vector{Operator{T}}, P.ops), bandwidths(P), size(P))
272+
ops = P.ops
273+
TimesOperator(ops isa AbstractVector{<:Operator{T}} ? ops : map(x -> strictconvert(Operator{T}, x), ops) ,
274+
bandwidths(P), size(P))
278275
end
279276
end
280277

281278

282-
283-
function promotetimes(opsin::Vector{<:Operator}, dsp = domainspace(last(opsin)),
284-
sz = _timessize(opsin))
279+
@static if VERSION > v"1.8"
280+
Base.@constprop :aggressive promotetimes(args...) = _promotetimes(args...)
281+
else
282+
promotetimes(args...) = _promotetimes(args...)
283+
end
284+
@inline function _promotetimes(opsin,
285+
dsp = domainspace(last(opsin)),
286+
sz = _timessize(opsin),
287+
anytimesop = true,
288+
)
285289

286290
@assert length(opsin) > 1 "need at least 2 operators"
291+
ops, bw = __promotetimes(opsin, dsp, anytimesop)
292+
TimesOperator(ops, bw, sz)
293+
end
294+
@inline function __promotetimes(opsin, dsp, anytimesop)
287295
ops=Vector{Operator{_promote_eltypeof(opsin)}}(undef,0)
288296
sizehint!(ops, length(opsin))
289297

290-
for k=length(opsin):-1:1
291-
if !isa(opsin[k],Conversion)
292-
op=promotedomainspace(opsin[k],dsp)
293-
dsp=rangespace(op)
294-
if isa(op,TimesOperator)
295-
append!(ops, view(op.ops, reverse(axes(op.ops,1))))
298+
for k = length(opsin):-1:1
299+
op = opsin[k]
300+
if !isa(op, Conversion)
301+
op_dsp=promotedomainspace(op, dsp)
302+
dsp=rangespace(op_dsp)
303+
if anytimesop && isa(op_dsp,TimesOperator)
304+
append!(ops, view(op_dsp.ops, reverse(axes(op_dsp.ops,1))))
296305
else
297-
push!(ops,op)
306+
push!(ops, op_dsp)
298307
end
299308
end
300309
end
301-
TimesOperator(reverse!(ops), bandwidthssum(ops), sz)
310+
reverse!(ops), bandwidthssum(ops)
311+
end
312+
@inline function __promotetimes(opsin::Tuple{Operator, Operator}, dsp, anytimesop)
313+
@assert !any(Base.Fix2(isa, TimesOperator), opsin) "TimesOperator should have been extracted already"
314+
315+
op1 = first(opsin)
316+
op2 = last(opsin)
317+
318+
if op2 isa Conversion && op1 isa Conversion
319+
op = Conversion(domainspace(op2), rangespace(op1))
320+
return [op], bandwidths(op)
321+
elseif op2 isa Conversion
322+
op = op1 rangespace(op2)
323+
return [op], bandwidths(op)
324+
elseif op1 isa Conversion
325+
op = op2 : domainspace(op1) rangespace(op2)
326+
return [op], bandwidths(op)
327+
else
328+
op1_dsp = op1:rangespace(op2)
329+
return [op1_dsp, op2], bandwidthssum((op1_dsp, op2))
330+
end
302331
end
303332

304333
domainspace(P::TimesOperator)=domainspace(last(P.ops))
@@ -486,18 +515,22 @@ end
486515

487516
for OP in (:(adjoint),:(transpose))
488517
@eval $OP(A::TimesOperator) = TimesOperator(
489-
strictconvert(Vector{Operator{eltype(A)}}, reverse!(map($OP,A.ops))),
518+
strictconvert(Vector, reverse!(map($OP,A.ops))),
490519
reverse(bandwidths(A)), reverse(size(A)))
491520
end
492521

522+
_collateops(A::TimesOperator, B::TimesOperator, ::typeof(*)) = [_extractops(A, *); _extractops(B, *)]
523+
_collateops(A::TimesOperator, B::Operator, ::typeof(*)) = [_extractops(A, *); _extractops(B, *)]
524+
_collateops(A::Operator, B::TimesOperator, ::typeof(*)) = [_extractops(A, *); _extractops(B, *)]
525+
_collateops(A::Operator, B::Operator, ::typeof(*)) = (A, B)
493526
function *(A::Operator,B::Operator)
494527
if isconstop(A)
495528
promoterangespace(strictconvert(Number,A)*B,rangespace(A))
496529
elseif isconstop(B)
497530
promotedomainspace(strictconvert(Number,B)*A,domainspace(B))
498531
else
499-
promotetimes([_extractops(A, *); _extractops(B, *)],
500-
domainspace(B), _timessize((A,B)))
532+
promotetimes(_collateops(A, B, *),
533+
domainspace(B), _timessize((A,B)), false)
501534
end
502535
end
503536

@@ -520,8 +553,12 @@ end
520553
*(A::Conversion,B::Operator) =
521554
isconstop(B) ? promotedomainspace(strictconvert(Number,B)*A,domainspace(B)) : TimesOperator(A,B)
522555

523-
^(A::Operator, p::Integer) = foldr(*, fill(A, p))
524-
556+
@inline function ^(A::Operator, p::Integer)
557+
p < 0 && return ^(inv(A), -p)
558+
p == 0 && return ConstantOperator(one(eltype(A)), domainspace(A))
559+
p <= 5 && return foldr(*, ntuple(_->A, p-1), init=A)
560+
return foldr(*, fill(A, p-2), init=A*A)
561+
end
525562

526563
+(A::Operator) = A
527564
-(A::Operator) = ConstantTimesOperator(-1,A)
@@ -601,7 +638,7 @@ function promotedomainspace(P::PlusOperator{T},sp::Space,cursp::Space) where T
601638
P
602639
else
603640
ops = [promotedomainspace(op,sp) for op in P.ops]
604-
promoteplus(Vector{Operator{_promote_eltypeof(ops)}}(ops))
641+
promoteplus(ops)
605642
end
606643
end
607644

0 commit comments

Comments
 (0)