@@ -4,13 +4,13 @@ export PlusOperator, TimesOperator, mul_coefficients
4
4
5
5
6
6
7
- struct PlusOperator{T,BI,SZ} <: Operator{T}
8
- ops:: Vector{Operator{T} }
7
+ struct PlusOperator{T,BI,SZ,O <: Operator{T} } <: Operator{T}
8
+ ops:: Vector{O }
9
9
bandwidths:: BI
10
10
sz :: SZ
11
- function PlusOperator {T,BI,SZ} (opsin:: Vector{Operator{T}} ,bi:: BI ,sz:: SZ ) where {T,BI,SZ}
11
+ function PlusOperator {T,BI,SZ,O } (opsin:: Vector{O} ,bi:: BI ,sz:: SZ ) where {T,O <: Operator{T} ,BI,SZ}
12
12
all (x -> size (x)== sz, opsin) || throw (" sizes of operators are incompatible" )
13
- new {T,BI,SZ} (opsin,bi,sz)
13
+ new {T,BI,SZ,O } (opsin,bi,sz)
14
14
end
15
15
end
16
16
@@ -19,11 +19,11 @@ size(P::PlusOperator, k::Integer) = P.sz[k]
19
19
20
20
bandwidthsmax (ops) = mapreduce (bandwidths, (t1,t2) -> max .(t1, t2), ops, init = (- 720 , - 720 ) #= approximate (-∞,-∞) =# )
21
21
22
- function PlusOperator (opsin:: Vector{Operator{T} } ,
22
+ function PlusOperator (opsin:: Vector{O } ,
23
23
bi:: Tuple{Any,Any} = bandwidthsmax (opsin),
24
24
sz:: Tuple{Any,Any} = size (first (opsin)),
25
- ) where {T }
26
- PlusOperator {T ,typeof(bi),typeof(sz)} (opsin,bi,sz)
25
+ ) where {O <: Operator }
26
+ PlusOperator {eltype(O) ,typeof(bi),typeof(sz),O } (opsin,bi,sz)
27
27
end
28
28
29
29
bandwidths (P:: PlusOperator ) = P. bandwidths
@@ -41,12 +41,13 @@ for (OP,mn) in ((:colstart,:min),(:colstop,:max),(:rowstart,:min),(:rowstop,:max
41
41
end
42
42
end
43
43
44
- function convert (:: Type{Operator{T}} ,P:: PlusOperator ) where T
44
+ function convert (:: Type{Operator{T}} , P:: PlusOperator ) where T
45
45
if T== eltype (P)
46
46
P
47
47
else
48
- PlusOperator {T,typeof(P.bandwidths),typeof(P.sz)} (
49
- Vector {Operator{T}} (P. ops),P. bandwidths,P. sz)
48
+ ops = P. ops
49
+ PlusOperator (ops isa AbstractVector{<: Operator{T} } ? ops : map (x -> strictconvert (Operator{T}, x), ops),
50
+ P. bandwidths,P. sz):: Operator{T}
50
51
end
51
52
end
52
53
@@ -72,12 +73,12 @@ _extractops(A, ::Any) = [A]
72
73
_extractops (A:: PlusOperator , :: typeof (+ )) = A. ops
73
74
74
75
function + (A:: Operator ,B:: Operator )
75
- v = Operator{ _promote_eltypeof (A,B)} [_extractops (A, + ); _extractops (B, + )]
76
+ v = [_extractops (A, + ); _extractops (B, + )]
76
77
promoteplus (v, size (A))
77
78
end
78
79
# Optimization for 3-term sum
79
80
function + (A:: Operator ,B:: Operator ,C:: Operator )
80
- v = Operator{ _promote_eltypeof (A,B,C)} [_extractops (A,+ ); _extractops (B, + ); _extractops (C, + )]
81
+ v = [_extractops (A,+ ); _extractops (B, + ); _extractops (C, + )]
81
82
promoteplus (v, size (A))
82
83
end
83
84
@@ -212,12 +213,12 @@ BLAS.axpy!(α,S::SubOperator{T,OP},A::AbstractMatrix) where {T,OP<:ConstantTimes
212
213
213
214
214
215
215
- struct TimesOperator{T,BI,SZ} <: Operator{T}
216
- ops:: Vector{Operator{T} }
216
+ struct TimesOperator{T,BI,SZ,O <: Operator{T} } <: Operator{T}
217
+ ops:: Vector{O }
217
218
bandwidths:: BI
218
219
sz:: SZ
219
220
220
- function TimesOperator {T,BI,SZ} (ops:: Vector{Operator{T}} ,bi:: BI ,sz:: SZ ) where {T,BI,SZ}
221
+ function TimesOperator {T,BI,SZ,O } (ops:: Vector{O} ,bi:: BI ,sz:: SZ ) where {T,O <: Operator{T} ,BI,SZ}
221
222
# check compatible
222
223
for k= 1 : length (ops)- 1
223
224
size (ops[k],2 ) == size (ops[k+ 1 ],1 ) || throw (ArgumentError (" incompatible operator sizes" ))
@@ -235,7 +236,7 @@ struct TimesOperator{T,BI,SZ} <: Operator{T}
235
236
newops = ops
236
237
end
237
238
238
- new {T,BI,SZ} (newops,bi,sz)
239
+ new {T,BI,SZ,O } (newops,bi,sz)
239
240
end
240
241
end
241
242
@@ -248,22 +249,16 @@ __bandwidthssum(A, B::NTuple{2,InfiniteCardinal{0}}) = B
248
249
__bandwidthssum (A, B) = reduce ((t1, t2) -> t1 .+ t2, (A, B), init = (0 ,0 ))
249
250
250
251
_timessize (ops) = (size (first (ops),1 ), size (last (ops),2 ))
251
- function TimesOperator (ops:: Vector{Operator{T} } ,
252
+ function TimesOperator (ops:: Vector{O } ,
252
253
bi:: Tuple{Any,Any} = bandwidthssum (ops),
253
- sz:: Tuple{Any,Any} = _timessize (ops)) where {T}
254
- TimesOperator {T,typeof(bi),typeof(sz)} (ops,bi,sz)
255
- end
256
-
257
- function TimesOperator (ops:: Vector{OT} ) where {OT<: Operator }
258
- TimesOperator (strictconvert (
259
- Vector{Operator{eltype (OT)}},ops),
260
- bandwidthssum (ops), _timessize (ops))
254
+ sz:: Tuple{Any,Any} = _timessize (ops)) where {T,O<: Operator{T} }
255
+ TimesOperator {T,typeof(bi),typeof(sz),O} (ops,bi,sz)
261
256
end
262
257
263
258
_extractops (A:: TimesOperator , :: typeof (* )) = A. ops
264
259
265
260
function TimesOperator (A:: Operator ,B:: Operator )
266
- v = Operator{ _promote_eltypeof (A,B)} [_extractops (A, * ); _extractops (B, * )]
261
+ v = [_extractops (A, * ); _extractops (B, * )]
267
262
TimesOperator (v, _bandwidthssum (A, B), _timessize ((A,B)))
268
263
end
269
264
@@ -274,31 +269,65 @@ function convert(::Type{Operator{T}},P::TimesOperator) where T
274
269
if T== eltype (P)
275
270
P
276
271
else
277
- TimesOperator (strictconvert (Vector{Operator{T}}, P. ops), bandwidths (P), size (P))
272
+ ops = P. ops
273
+ TimesOperator (ops isa AbstractVector{<: Operator{T} } ? ops : map (x -> strictconvert (Operator{T}, x), ops) ,
274
+ bandwidths (P), size (P))
278
275
end
279
276
end
280
277
281
278
282
-
283
- function promotetimes (opsin:: Vector{<:Operator} , dsp = domainspace (last (opsin)),
284
- sz = _timessize (opsin))
279
+ @static if VERSION > v " 1.8"
280
+ Base. @constprop :aggressive promotetimes (args... ) = _promotetimes (args... )
281
+ else
282
+ promotetimes (args... ) = _promotetimes (args... )
283
+ end
284
+ @inline function _promotetimes (opsin,
285
+ dsp = domainspace (last (opsin)),
286
+ sz = _timessize (opsin),
287
+ anytimesop = true ,
288
+ )
285
289
286
290
@assert length (opsin) > 1 " need at least 2 operators"
291
+ ops, bw = __promotetimes (opsin, dsp, anytimesop)
292
+ TimesOperator (ops, bw, sz)
293
+ end
294
+ @inline function __promotetimes (opsin, dsp, anytimesop)
287
295
ops= Vector {Operator{_promote_eltypeof(opsin)}} (undef,0 )
288
296
sizehint! (ops, length (opsin))
289
297
290
- for k= length (opsin): - 1 : 1
291
- if ! isa (opsin[k],Conversion)
292
- op= promotedomainspace (opsin[k],dsp)
293
- dsp= rangespace (op)
294
- if isa (op,TimesOperator)
295
- append! (ops, view (op. ops, reverse (axes (op. ops,1 ))))
298
+ for k = length (opsin): - 1 : 1
299
+ op = opsin[k]
300
+ if ! isa (op, Conversion)
301
+ op_dsp= promotedomainspace (op, dsp)
302
+ dsp= rangespace (op_dsp)
303
+ if anytimesop && isa (op_dsp,TimesOperator)
304
+ append! (ops, view (op_dsp. ops, reverse (axes (op_dsp. ops,1 ))))
296
305
else
297
- push! (ops,op )
306
+ push! (ops, op_dsp )
298
307
end
299
308
end
300
309
end
301
- TimesOperator (reverse! (ops), bandwidthssum (ops), sz)
310
+ reverse! (ops), bandwidthssum (ops)
311
+ end
312
+ @inline function __promotetimes (opsin:: Tuple{Operator, Operator} , dsp, anytimesop)
313
+ @assert ! any (Base. Fix2 (isa, TimesOperator), opsin) " TimesOperator should have been extracted already"
314
+
315
+ op1 = first (opsin)
316
+ op2 = last (opsin)
317
+
318
+ if op2 isa Conversion && op1 isa Conversion
319
+ op = Conversion (domainspace (op2), rangespace (op1))
320
+ return [op], bandwidths (op)
321
+ elseif op2 isa Conversion
322
+ op = op1 → rangespace (op2)
323
+ return [op], bandwidths (op)
324
+ elseif op1 isa Conversion
325
+ op = op2 : domainspace (op1) → rangespace (op2)
326
+ return [op], bandwidths (op)
327
+ else
328
+ op1_dsp = op1: rangespace (op2)
329
+ return [op1_dsp, op2], bandwidthssum ((op1_dsp, op2))
330
+ end
302
331
end
303
332
304
333
domainspace (P:: TimesOperator )= domainspace (last (P. ops))
@@ -486,18 +515,22 @@ end
486
515
487
516
for OP in (:(adjoint),:(transpose))
488
517
@eval $ OP (A:: TimesOperator ) = TimesOperator (
489
- strictconvert (Vector{Operator{ eltype (A)}} , reverse! (map ($ OP,A. ops))),
518
+ strictconvert (Vector, reverse! (map ($ OP,A. ops))),
490
519
reverse (bandwidths (A)), reverse (size (A)))
491
520
end
492
521
522
+ _collateops (A:: TimesOperator , B:: TimesOperator , :: typeof (* )) = [_extractops (A, * ); _extractops (B, * )]
523
+ _collateops (A:: TimesOperator , B:: Operator , :: typeof (* )) = [_extractops (A, * ); _extractops (B, * )]
524
+ _collateops (A:: Operator , B:: TimesOperator , :: typeof (* )) = [_extractops (A, * ); _extractops (B, * )]
525
+ _collateops (A:: Operator , B:: Operator , :: typeof (* )) = (A, B)
493
526
function * (A:: Operator ,B:: Operator )
494
527
if isconstop (A)
495
528
promoterangespace (strictconvert (Number,A)* B,rangespace (A))
496
529
elseif isconstop (B)
497
530
promotedomainspace (strictconvert (Number,B)* A,domainspace (B))
498
531
else
499
- promotetimes ([ _extractops (A, * ); _extractops ( B, * )] ,
500
- domainspace (B), _timessize ((A,B)))
532
+ promotetimes (_collateops (A, B, * ),
533
+ domainspace (B), _timessize ((A,B)), false )
501
534
end
502
535
end
503
536
520
553
* (A:: Conversion ,B:: Operator ) =
521
554
isconstop (B) ? promotedomainspace (strictconvert (Number,B)* A,domainspace (B)) : TimesOperator (A,B)
522
555
523
- ^ (A:: Operator , p:: Integer ) = foldr (* , fill (A, p))
524
-
556
+ @inline function ^ (A:: Operator , p:: Integer )
557
+ p < 0 && return ^ (inv (A), - p)
558
+ p == 0 && return ConstantOperator (one (eltype (A)), domainspace (A))
559
+ p <= 5 && return foldr (* , ntuple (_-> A, p- 1 ), init= A)
560
+ return foldr (* , fill (A, p- 2 ), init= A* A)
561
+ end
525
562
526
563
+ (A:: Operator ) = A
527
564
- (A:: Operator ) = ConstantTimesOperator (- 1 ,A)
@@ -601,7 +638,7 @@ function promotedomainspace(P::PlusOperator{T},sp::Space,cursp::Space) where T
601
638
P
602
639
else
603
640
ops = [promotedomainspace (op,sp) for op in P. ops]
604
- promoteplus (Vector {Operator{_promote_eltypeof( ops)}} (ops) )
641
+ promoteplus (ops)
605
642
end
606
643
end
607
644
0 commit comments