Skip to content

Commit 6353b9c

Browse files
authored
Subblockbandwidths in plus/times operators (#337)
1 parent 787429b commit 6353b9c

File tree

1 file changed

+51
-47
lines changed

1 file changed

+51
-47
lines changed

src/Operators/general/algebra.jl

Lines changed: 51 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,20 @@ export PlusOperator, TimesOperator, mul_coefficients
44

55

66

7-
struct PlusOperator{T,BW,SZ,O<:Operator{T},BBW} <: Operator{T}
7+
struct PlusOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
88
ops::Vector{O}
99
bandwidths::BW
1010
sz::SZ
1111
blockbandwidths::BBW
12+
subblockbandwidths::SBBW
1213

13-
function PlusOperator{T,BW,SZ,O,BBW}(opsin::Vector{O}, bi::BW, sz::SZ, bbw::BBW) where {T,O<:Operator{T},BW,SZ,BBW}
14+
function PlusOperator{T,BW,SZ,O,BBW,SBBW}(opsin::Vector{O}, bw::BW,
15+
sz::SZ, bbw::BBW, sbbw::SBBW) where {T,O<:Operator{T},BW,SZ,BBW,SBBW}
1416
all(x -> size(x) == sz, opsin) || throw("sizes of operators are incompatible")
15-
new{T,BW,SZ,O,BBW}(opsin, bi, sz, bbw)
17+
new{T,BW,SZ,O,BBW,SBBW}(opsin, bw, sz, bbw,sbbw)
1618
end
1719
end
1820

19-
size(P::PlusOperator) = P.sz
20-
size(P::PlusOperator, k::Integer) = P.sz[k]
21-
2221
bandwidthsmax(ops, f=bandwidths) = mapreduce(f, (t1, t2) -> max.(t1, t2), ops, init=(-720, -720)) #= approximate (-∞,-∞) =#
2322

2423
function PlusOperator(opsin::Vector{O}, args...) where {O<:Operator}
@@ -28,17 +27,12 @@ function PlusOperator{ET}(opsin::Vector{O},
2827
bw::Tuple{Any,Any}=bandwidthsmax(opsin),
2928
sz::Tuple{Any,Any}=size(first(opsin)),
3029
bbw::Tuple{Any,Any}=bandwidthsmax(opsin, blockbandwidths),
30+
sbbw::Tuple{Any,Any}=bandwidthsmax(opsin, subblockbandwidths),
3131
) where {ET,O<:Operator{ET}}
3232

33-
PlusOperator{ET,typeof(bw),typeof(sz),O,typeof(bbw)}(opsin, bw, sz, bbw)
33+
PlusOperator{ET,typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw)}(opsin, bw, sz, bbw, sbbw)
3434
end
3535

36-
bandwidths(P::PlusOperator) = P.bandwidths
37-
blockbandwidths(P::PlusOperator) = P.blockbandwidths
38-
subblockbandwidths(P::PlusOperator) = bandwidthsmax(P.ops, subblockbandwidths)
39-
40-
israggedbelow(P::PlusOperator) = isbandedbelow(P) || all(israggedbelow, P.ops)
41-
4236
for (OP, mn) in ((:colstart, :min), (:colstop, :max), (:rowstart, :min), (:rowstop, :max))
4337
defOP = Symbol(:default_, OP)
4438
@eval function $OP(P::PlusOperator, k::Integer)
@@ -63,19 +57,16 @@ function convert(::Type{Operator{T}}, P::PlusOperator) where {T}
6357
ops = P.ops
6458
PlusOperator(eltype(ops) <: Operator{T} ? ops :
6559
_convertops(Operator{T}, ops),
66-
bandwidths(P), size(P), blockbandwidths(P))::Operator{T}
60+
bandwidths(P), size(P), blockbandwidths(P), subblockbandwidths(P))::Operator{T}
6761
end
6862
end
6963

7064
function promoteplus(opsin, sz=size(first(opsin)))
7165
ops = filter(!iszeroop, opsin)
7266
ET = promote_eltypeof(opsin)
7367
v = promotespaces(ops)
74-
PlusOperator{ET}(convert_vector(v), bandwidthsmax(v), sz, bandwidthsmax(v, blockbandwidths))
75-
end
76-
77-
for OP in (:domainspace, :rangespace)
78-
@eval $OP(P::PlusOperator) = $OP(first(P.ops))
68+
PlusOperator{ET}(convert_vector(v), bandwidthsmax(v), sz,
69+
bandwidthsmax(v, blockbandwidths), bandwidthsmax(v, subblockbandwidths))
7970
end
8071

8172
domain(P::PlusOperator) = commondomain(P.ops)
@@ -224,13 +215,15 @@ BLAS.axpy!(α, S::SubOperator{T,OP}, A::AbstractMatrix) where {T,OP<:ConstantTim
224215

225216

226217

227-
struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW} <: Operator{T}
218+
struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
228219
ops::Vector{O}
229220
bandwidths::BW
230221
sz::SZ
231222
blockbandwidths::BBW
223+
subblockbandwidths::SBBW
232224

233-
function TimesOperator{T,BW,SZ,O,BBW}(ops::Vector{O}, bw::BW, sz::SZ, bbw::BBW) where {T,O<:Operator{T},BW,SZ,BBW}
225+
function TimesOperator{T,BW,SZ,O,BBW,SBBW}(ops::Vector{O}, bw::BW,
226+
sz::SZ, bbw::BBW, sbbw::SBBW) where {T,O<:Operator{T},BW,SZ,BBW,SBBW}
234227
# check compatible
235228
for k = 1:length(ops)-1
236229
size(ops[k], 2) == size(ops[k+1], 1) || throw(ArgumentError("incompatible operator sizes"))
@@ -248,10 +241,12 @@ struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW} <: Operator{T}
248241
newops = ops
249242
end
250243

251-
new{T,BW,SZ,O,BBW}(newops, bw, sz, bbw)
244+
new{T,BW,SZ,O,BBW,SBBW}(newops, bw, sz, bbw, sbbw)
252245
end
253246
end
254247

248+
const PlusOrTimesOp = Union{PlusOperator,TimesOperator}
249+
255250
bandwidthssum(P, f=bandwidths) = mapreduce(f, (t1, t2) -> t1 .+ t2, P, init=(0, 0))
256251
_bandwidthssum(A::Operator, B::Operator, f=bandwidths) = __bandwidthssum(f(A), f(B))
257252
__bandwidthssum(A::NTuple{2,InfiniteCardinal{0}}, B::NTuple{2,InfiniteCardinal{0}}) = A
@@ -261,19 +256,22 @@ __bandwidthssum(A, B) = reduce((t1, t2) -> t1 .+ t2, (A, B), init=(0, 0))
261256

262257
_timessize(ops) = (size(first(ops), 1), size(last(ops), 2))
263258
function TimesOperator(ops::AbstractVector{O},
264-
bw::Tuple{Any,Any}=bandwidthssum(ops),
265-
sz::Tuple{Any,Any}=_timessize(ops),
266-
bbw::Tuple{Any,Any}=bandwidthssum(ops, blockbandwidths),
267-
) where {T,O<:Operator{T}}
268-
TimesOperator{T,typeof(bw),typeof(sz),O,typeof(bbw)}(convert_vector(ops), bw, sz, bbw)
259+
bw::Tuple{Any,Any}=bandwidthssum(ops),
260+
sz::Tuple{Any,Any}=_timessize(ops),
261+
bbw::Tuple{Any,Any}=bandwidthssum(ops, blockbandwidths),
262+
sbbw::Tuple{Any,Any}=bandwidthssum(ops, subblockbandwidths),
263+
) where {T,O<:Operator{T}}
264+
TimesOperator{T,typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw)}(convert_vector(ops),
265+
bw, sz, bbw, sbbw)
269266
end
270267

271268
_extractops(A::TimesOperator, ::typeof(*)) = A.ops
272269

273270
function TimesOperator(A::Operator, B::Operator)
274271
v = collateops(*, A, B)
275272
TimesOperator(convert_vector(v), _bandwidthssum(A, B), _timessize((A, B)),
276-
_bandwidthssum(A, B, blockbandwidths))
273+
_bandwidthssum(A, B, blockbandwidths),
274+
_bandwidthssum(A, B, subblockbandwidths))
277275
end
278276

279277

@@ -286,7 +284,8 @@ function convert(::Type{Operator{T}}, P::TimesOperator) where {T}
286284
ops = P.ops
287285
TimesOperator(eltype(ops) <: Operator{T} ? ops :
288286
_convertops(Operator{T}, ops),
289-
bandwidths(P), size(P), blockbandwidths(P))::Operator{T}
287+
bandwidths(P), size(P), blockbandwidths(P),
288+
subblockbandwidths(P))::Operator{T}
290289
end
291290
end
292291

@@ -303,8 +302,8 @@ end
303302
)
304303

305304
@assert length(opsin) > 1 "need at least 2 operators"
306-
ops, bw, bbw = __promotetimes(opsin, dsp, anytimesop)
307-
TimesOperator(ops, bw, sz, bbw)
305+
ops, bw, bbw, sbbw = __promotetimes(opsin, dsp, anytimesop)
306+
TimesOperator(ops, bw, sz, bbw, sbbw)
308307
end
309308
@inline function __promotetimes(opsin, dsp, anytimesop)
310309
ops = Vector{Operator{promote_eltypeof(opsin)}}(undef, 0)
@@ -322,8 +321,10 @@ end
322321
end
323322
end
324323
end
325-
reverse!(ops), bandwidthssum(ops), bandwidthssum(ops, blockbandwidths)
324+
reverse!(ops), bandwidthssum(ops), bandwidthssum(ops, blockbandwidths),
325+
bandwidthssum(ops, subblockbandwidths)
326326
end
327+
_op_bws(op) = [op], bandwidths(op), blockbandwidths(op), subblockbandwidths(op)
327328
@inline function __promotetimes(opsin::Tuple{Operator,Operator}, dsp, anytimesop)
328329
@assert !any(Base.Fix2(isa, TimesOperator), opsin) "TimesOperator should have been extracted already"
329330

@@ -332,36 +333,37 @@ end
332333

333334
if op2 isa Conversion && op1 isa Conversion
334335
op = Conversion(domainspace(op2), rangespace(op1))
335-
return [op], bandwidths(op), blockbandwidths(op)
336+
return _op_bws(op)
336337
elseif op2 isa Conversion
337338
op = op1 rangespace(op2)
338-
return [op], bandwidths(op), blockbandwidths(op)
339+
return _op_bws(op)
339340
elseif op1 isa Conversion
340341
op = op2:domainspace(op1) rangespace(op2)
341-
return [op], bandwidths(op), blockbandwidths(op)
342+
return _op_bws(op)
342343
else
343344
op2_dsp = op2:dsp
344345
op1_dsp = op1:rangespace(op2_dsp)
345346
return [op1_dsp, op2_dsp], bandwidthssum((op1_dsp, op2_dsp)),
346-
bandwidthssum((op1_dsp, op2_dsp), blockbandwidths)
347+
bandwidthssum((op1_dsp, op2_dsp), blockbandwidths),
348+
bandwidthssum((op1_dsp, op2_dsp), subblockbandwidths)
347349
end
348350
end
349351

350-
domainspace(P::TimesOperator) = domainspace(last(P.ops))
351-
rangespace(P::TimesOperator) = rangespace(first(P.ops))
352+
domainspace(P::PlusOrTimesOp) = domainspace(last(P.ops))
353+
rangespace(P::PlusOrTimesOp) = rangespace(first(P.ops))
352354

353355
domain(P::TimesOperator) = commondomain(P.ops)
354356

355-
size(P::TimesOperator, k::Integer) = P.sz[k]
356-
size(P::TimesOperator) = P.sz
357+
size(P::PlusOrTimesOp, k::Integer) = P.sz[k]
358+
size(P::PlusOrTimesOp) = P.sz
357359

358-
bandwidths(P::TimesOperator) = P.bandwidths
359-
blockbandwidths(P::TimesOperator) = P.blockbandwidths
360-
subblockbandwidths(P::TimesOperator) = bandwidthssum(P.ops, subblockbandwidths)
360+
bandwidths(P::PlusOrTimesOp) = P.bandwidths
361+
blockbandwidths(P::PlusOrTimesOp) = P.blockbandwidths
362+
subblockbandwidths(P::PlusOrTimesOp) = P.subblockbandwidths
361363

362-
isbandedblockbanded(P::Union{PlusOperator,TimesOperator}) = all(isbandedblockbanded, P.ops)
364+
isbandedblockbanded(P::PlusOrTimesOp) = all(isbandedblockbanded, P.ops)
363365

364-
israggedbelow(P::TimesOperator) = isbandedbelow(P) || all(israggedbelow, P.ops)
366+
israggedbelow(P::PlusOrTimesOp) = isbandedbelow(P) || all(israggedbelow, P.ops)
365367

366368
Base.stride(P::TimesOperator) = mapreduce(stride, gcd, P.ops)
367369

@@ -548,10 +550,12 @@ end
548550
for OP in (:(adjoint), :(transpose))
549551
@eval $OP(A::TimesOperator) = TimesOperator(
550552
strictconvert(Vector, reverse!(map($OP, A.ops))),
551-
reverse(bandwidths(A)), reverse(size(A)), reverse(blockbandwidths(A)))
553+
reverse(bandwidths(A)), reverse(size(A)),
554+
reverse(blockbandwidths(A)),
555+
reverse(subblockbandwidths(A))
556+
)
552557
end
553558

554-
const PlusOrTimesOp = Union{PlusOperator,TimesOperator}
555559
anyplustimes(f, op::Operator, ops...) = anyplustimes(f, ops...)
556560
anyplustimes(::typeof(+), op::PlusOperator, ops...) = true
557561
anyplustimes(::typeof(*), op::TimesOperator, ops...) = true

0 commit comments

Comments
 (0)