@@ -4,21 +4,20 @@ export PlusOperator, TimesOperator, mul_coefficients
4
4
5
5
6
6
7
- struct PlusOperator{T,BW,SZ,O<: Operator{T} ,BBW} <: Operator{T}
7
+ struct PlusOperator{T,BW,SZ,O<: Operator{T} ,BBW,SBBW } <: Operator{T}
8
8
ops:: Vector{O}
9
9
bandwidths:: BW
10
10
sz:: SZ
11
11
blockbandwidths:: BBW
12
+ subblockbandwidths:: SBBW
12
13
13
- function PlusOperator {T,BW,SZ,O,BBW} (opsin:: Vector{O} , bi:: BW , sz:: SZ , bbw:: BBW ) where {T,O<: Operator{T} ,BW,SZ,BBW}
14
+ function PlusOperator {T,BW,SZ,O,BBW,SBBW} (opsin:: Vector{O} , bw:: BW ,
15
+ sz:: SZ , bbw:: BBW , sbbw:: SBBW ) where {T,O<: Operator{T} ,BW,SZ,BBW,SBBW}
14
16
all (x -> size (x) == sz, opsin) || throw (" sizes of operators are incompatible" )
15
- new {T,BW,SZ,O,BBW} (opsin, bi , sz, bbw)
17
+ new {T,BW,SZ,O,BBW,SBBW } (opsin, bw , sz, bbw,sbbw )
16
18
end
17
19
end
18
20
19
- size (P:: PlusOperator ) = P. sz
20
- size (P:: PlusOperator , k:: Integer ) = P. sz[k]
21
-
22
21
bandwidthsmax (ops, f= bandwidths) = mapreduce (f, (t1, t2) -> max .(t1, t2), ops, init= (- 720 , - 720 )) #= approximate (-∞,-∞) =#
23
22
24
23
function PlusOperator (opsin:: Vector{O} , args... ) where {O<: Operator }
@@ -28,17 +27,12 @@ function PlusOperator{ET}(opsin::Vector{O},
28
27
bw:: Tuple{Any,Any} = bandwidthsmax (opsin),
29
28
sz:: Tuple{Any,Any} = size (first (opsin)),
30
29
bbw:: Tuple{Any,Any} = bandwidthsmax (opsin, blockbandwidths),
30
+ sbbw:: Tuple{Any,Any} = bandwidthsmax (opsin, subblockbandwidths),
31
31
) where {ET,O<: Operator{ET} }
32
32
33
- PlusOperator {ET,typeof(bw),typeof(sz),O,typeof(bbw)} (opsin, bw, sz, bbw)
33
+ PlusOperator {ET,typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw) } (opsin, bw, sz, bbw, sbbw )
34
34
end
35
35
36
- bandwidths (P:: PlusOperator ) = P. bandwidths
37
- blockbandwidths (P:: PlusOperator ) = P. blockbandwidths
38
- subblockbandwidths (P:: PlusOperator ) = bandwidthsmax (P. ops, subblockbandwidths)
39
-
40
- israggedbelow (P:: PlusOperator ) = isbandedbelow (P) || all (israggedbelow, P. ops)
41
-
42
36
for (OP, mn) in ((:colstart , :min ), (:colstop , :max ), (:rowstart , :min ), (:rowstop , :max ))
43
37
defOP = Symbol (:default_ , OP)
44
38
@eval function $OP (P:: PlusOperator , k:: Integer )
@@ -63,19 +57,16 @@ function convert(::Type{Operator{T}}, P::PlusOperator) where {T}
63
57
ops = P. ops
64
58
PlusOperator (eltype (ops) <: Operator{T} ? ops :
65
59
_convertops (Operator{T}, ops),
66
- bandwidths (P), size (P), blockbandwidths (P)):: Operator{T}
60
+ bandwidths (P), size (P), blockbandwidths (P), subblockbandwidths (P) ):: Operator{T}
67
61
end
68
62
end
69
63
70
64
function promoteplus (opsin, sz= size (first (opsin)))
71
65
ops = filter (! iszeroop, opsin)
72
66
ET = promote_eltypeof (opsin)
73
67
v = promotespaces (ops)
74
- PlusOperator {ET} (convert_vector (v), bandwidthsmax (v), sz, bandwidthsmax (v, blockbandwidths))
75
- end
76
-
77
- for OP in (:domainspace , :rangespace )
78
- @eval $ OP (P:: PlusOperator ) = $ OP (first (P. ops))
68
+ PlusOperator {ET} (convert_vector (v), bandwidthsmax (v), sz,
69
+ bandwidthsmax (v, blockbandwidths), bandwidthsmax (v, subblockbandwidths))
79
70
end
80
71
81
72
domain (P:: PlusOperator ) = commondomain (P. ops)
@@ -224,13 +215,15 @@ BLAS.axpy!(α, S::SubOperator{T,OP}, A::AbstractMatrix) where {T,OP<:ConstantTim
224
215
225
216
226
217
227
- struct TimesOperator{T,BW,SZ,O<: Operator{T} ,BBW} <: Operator{T}
218
+ struct TimesOperator{T,BW,SZ,O<: Operator{T} ,BBW,SBBW } <: Operator{T}
228
219
ops:: Vector{O}
229
220
bandwidths:: BW
230
221
sz:: SZ
231
222
blockbandwidths:: BBW
223
+ subblockbandwidths:: SBBW
232
224
233
- function TimesOperator {T,BW,SZ,O,BBW} (ops:: Vector{O} , bw:: BW , sz:: SZ , bbw:: BBW ) where {T,O<: Operator{T} ,BW,SZ,BBW}
225
+ function TimesOperator {T,BW,SZ,O,BBW,SBBW} (ops:: Vector{O} , bw:: BW ,
226
+ sz:: SZ , bbw:: BBW , sbbw:: SBBW ) where {T,O<: Operator{T} ,BW,SZ,BBW,SBBW}
234
227
# check compatible
235
228
for k = 1 : length (ops)- 1
236
229
size (ops[k], 2 ) == size (ops[k+ 1 ], 1 ) || throw (ArgumentError (" incompatible operator sizes" ))
@@ -248,10 +241,12 @@ struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW} <: Operator{T}
248
241
newops = ops
249
242
end
250
243
251
- new {T,BW,SZ,O,BBW} (newops, bw, sz, bbw)
244
+ new {T,BW,SZ,O,BBW,SBBW } (newops, bw, sz, bbw, sbbw )
252
245
end
253
246
end
254
247
248
+ const PlusOrTimesOp = Union{PlusOperator,TimesOperator}
249
+
255
250
bandwidthssum (P, f= bandwidths) = mapreduce (f, (t1, t2) -> t1 .+ t2, P, init= (0 , 0 ))
256
251
_bandwidthssum (A:: Operator , B:: Operator , f= bandwidths) = __bandwidthssum (f (A), f (B))
257
252
__bandwidthssum (A:: NTuple{2,InfiniteCardinal{0}} , B:: NTuple{2,InfiniteCardinal{0}} ) = A
@@ -261,19 +256,22 @@ __bandwidthssum(A, B) = reduce((t1, t2) -> t1 .+ t2, (A, B), init=(0, 0))
261
256
262
257
_timessize (ops) = (size (first (ops), 1 ), size (last (ops), 2 ))
263
258
function TimesOperator (ops:: AbstractVector{O} ,
264
- bw:: Tuple{Any,Any} = bandwidthssum (ops),
265
- sz:: Tuple{Any,Any} = _timessize (ops),
266
- bbw:: Tuple{Any,Any} = bandwidthssum (ops, blockbandwidths),
267
- ) where {T,O<: Operator{T} }
268
- TimesOperator {T,typeof(bw),typeof(sz),O,typeof(bbw)} (convert_vector (ops), bw, sz, bbw)
259
+ bw:: Tuple{Any,Any} = bandwidthssum (ops),
260
+ sz:: Tuple{Any,Any} = _timessize (ops),
261
+ bbw:: Tuple{Any,Any} = bandwidthssum (ops, blockbandwidths),
262
+ sbbw:: Tuple{Any,Any} = bandwidthssum (ops, subblockbandwidths),
263
+ ) where {T,O<: Operator{T} }
264
+ TimesOperator {T,typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw)} (convert_vector (ops),
265
+ bw, sz, bbw, sbbw)
269
266
end
270
267
271
268
_extractops (A:: TimesOperator , :: typeof (* )) = A. ops
272
269
273
270
function TimesOperator (A:: Operator , B:: Operator )
274
271
v = collateops (* , A, B)
275
272
TimesOperator (convert_vector (v), _bandwidthssum (A, B), _timessize ((A, B)),
276
- _bandwidthssum (A, B, blockbandwidths))
273
+ _bandwidthssum (A, B, blockbandwidths),
274
+ _bandwidthssum (A, B, subblockbandwidths))
277
275
end
278
276
279
277
@@ -286,7 +284,8 @@ function convert(::Type{Operator{T}}, P::TimesOperator) where {T}
286
284
ops = P. ops
287
285
TimesOperator (eltype (ops) <: Operator{T} ? ops :
288
286
_convertops (Operator{T}, ops),
289
- bandwidths (P), size (P), blockbandwidths (P)):: Operator{T}
287
+ bandwidths (P), size (P), blockbandwidths (P),
288
+ subblockbandwidths (P)):: Operator{T}
290
289
end
291
290
end
292
291
303
302
)
304
303
305
304
@assert length (opsin) > 1 " need at least 2 operators"
306
- ops, bw, bbw = __promotetimes (opsin, dsp, anytimesop)
307
- TimesOperator (ops, bw, sz, bbw)
305
+ ops, bw, bbw, sbbw = __promotetimes (opsin, dsp, anytimesop)
306
+ TimesOperator (ops, bw, sz, bbw, sbbw )
308
307
end
309
308
@inline function __promotetimes (opsin, dsp, anytimesop)
310
309
ops = Vector {Operator{promote_eltypeof(opsin)}} (undef, 0 )
322
321
end
323
322
end
324
323
end
325
- reverse! (ops), bandwidthssum (ops), bandwidthssum (ops, blockbandwidths)
324
+ reverse! (ops), bandwidthssum (ops), bandwidthssum (ops, blockbandwidths),
325
+ bandwidthssum (ops, subblockbandwidths)
326
326
end
327
+ _op_bws (op) = [op], bandwidths (op), blockbandwidths (op), subblockbandwidths (op)
327
328
@inline function __promotetimes (opsin:: Tuple{Operator,Operator} , dsp, anytimesop)
328
329
@assert ! any (Base. Fix2 (isa, TimesOperator), opsin) " TimesOperator should have been extracted already"
329
330
@@ -332,36 +333,37 @@ end
332
333
333
334
if op2 isa Conversion && op1 isa Conversion
334
335
op = Conversion (domainspace (op2), rangespace (op1))
335
- return [op], bandwidths (op), blockbandwidths (op)
336
+ return _op_bws (op)
336
337
elseif op2 isa Conversion
337
338
op = op1 → rangespace (op2)
338
- return [op], bandwidths (op), blockbandwidths (op)
339
+ return _op_bws (op)
339
340
elseif op1 isa Conversion
340
341
op = op2: domainspace (op1) → rangespace (op2)
341
- return [op], bandwidths (op), blockbandwidths (op)
342
+ return _op_bws (op)
342
343
else
343
344
op2_dsp = op2: dsp
344
345
op1_dsp = op1: rangespace (op2_dsp)
345
346
return [op1_dsp, op2_dsp], bandwidthssum ((op1_dsp, op2_dsp)),
346
- bandwidthssum ((op1_dsp, op2_dsp), blockbandwidths)
347
+ bandwidthssum ((op1_dsp, op2_dsp), blockbandwidths),
348
+ bandwidthssum ((op1_dsp, op2_dsp), subblockbandwidths)
347
349
end
348
350
end
349
351
350
- domainspace (P:: TimesOperator ) = domainspace (last (P. ops))
351
- rangespace (P:: TimesOperator ) = rangespace (first (P. ops))
352
+ domainspace (P:: PlusOrTimesOp ) = domainspace (last (P. ops))
353
+ rangespace (P:: PlusOrTimesOp ) = rangespace (first (P. ops))
352
354
353
355
domain (P:: TimesOperator ) = commondomain (P. ops)
354
356
355
- size (P:: TimesOperator , k:: Integer ) = P. sz[k]
356
- size (P:: TimesOperator ) = P. sz
357
+ size (P:: PlusOrTimesOp , k:: Integer ) = P. sz[k]
358
+ size (P:: PlusOrTimesOp ) = P. sz
357
359
358
- bandwidths (P:: TimesOperator ) = P. bandwidths
359
- blockbandwidths (P:: TimesOperator ) = P. blockbandwidths
360
- subblockbandwidths (P:: TimesOperator ) = bandwidthssum (P . ops, subblockbandwidths)
360
+ bandwidths (P:: PlusOrTimesOp ) = P. bandwidths
361
+ blockbandwidths (P:: PlusOrTimesOp ) = P. blockbandwidths
362
+ subblockbandwidths (P:: PlusOrTimesOp ) = P . subblockbandwidths
361
363
362
- isbandedblockbanded (P:: Union{PlusOperator,TimesOperator} ) = all (isbandedblockbanded, P. ops)
364
+ isbandedblockbanded (P:: PlusOrTimesOp ) = all (isbandedblockbanded, P. ops)
363
365
364
- israggedbelow (P:: TimesOperator ) = isbandedbelow (P) || all (israggedbelow, P. ops)
366
+ israggedbelow (P:: PlusOrTimesOp ) = isbandedbelow (P) || all (israggedbelow, P. ops)
365
367
366
368
Base. stride (P:: TimesOperator ) = mapreduce (stride, gcd, P. ops)
367
369
@@ -548,10 +550,12 @@ end
548
550
for OP in (:(adjoint), :(transpose))
549
551
@eval $ OP (A:: TimesOperator ) = TimesOperator (
550
552
strictconvert (Vector, reverse! (map ($ OP, A. ops))),
551
- reverse (bandwidths (A)), reverse (size (A)), reverse (blockbandwidths (A)))
553
+ reverse (bandwidths (A)), reverse (size (A)),
554
+ reverse (blockbandwidths (A)),
555
+ reverse (subblockbandwidths (A))
556
+ )
552
557
end
553
558
554
- const PlusOrTimesOp = Union{PlusOperator,TimesOperator}
555
559
anyplustimes (f, op:: Operator , ops... ) = anyplustimes (f, ops... )
556
560
anyplustimes (:: typeof (+ ), op:: PlusOperator , ops... ) = true
557
561
anyplustimes (:: typeof (* ), op:: TimesOperator , ops... ) = true
0 commit comments