Skip to content

Commit ee171b6

Browse files
committed
Store domainspace in TimesOperator
1 parent 0202834 commit ee171b6

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

src/Operators/general/algebra.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,22 +223,26 @@ axpy!(α, S::SubOperator{T,OP}, A::AbstractMatrix) where {T,OP<:ConstantTimesOpe
223223

224224

225225

226-
struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
226+
struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW,D} <: Operator{T}
227227
ops::Vector{O}
228228
bandwidths::BW
229229
sz::SZ
230230
blockbandwidths::BBW
231231
subblockbandwidths::SBBW
232232
isbandedblockbanded::Bool
233233
israggedbelow::Bool
234+
domainspace::D
234235

235236
function TimesOperator{T,BW,SZ,O,BBW,SBBW}(ops::Vector{O}, bw::BW,
236237
sz::SZ, bbw::BBW, sbbw::SBBW,
237-
ibbb::Bool, irb::Bool) where {T,O<:Operator{T},BW,SZ,BBW,SBBW}
238-
# check compatible
238+
ibbb::Bool, irb::Bool, dsp::D) where {T,O<:Operator{T},BW,SZ,BBW,SBBW,D}
239+
240+
dsp == domainspace(ops[end]) || throw(ArgumentError("incompatible domainspace"))
241+
242+
# check compatibility
239243
for k = 1:length(ops)-1
240244
size(ops[k], 2) == size(ops[k+1], 1) || throw(ArgumentError("incompatible operator sizes"))
241-
spacescompatible(domainspace(ops[k]), rangespace(ops[k+1])) || throw(ArgumentError("imcompatible spaces at index $k"))
245+
spacescompatible(domainspace(ops[k]), rangespace(ops[k+1])) || throw(ArgumentError("incompatible spaces at index $k"))
242246
end
243247

244248
# remove TimesOperators buried inside ops
@@ -252,7 +256,7 @@ struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
252256
newops = ops
253257
end
254258

255-
new{T,BW,SZ,O,BBW,SBBW}(newops, bw, sz, bbw, sbbw, ibbb, irb)
259+
new{T,BW,SZ,O,BBW,SBBW,D}(newops, bw, sz, bbw, sbbw, ibbb, irb, dsp)
256260
end
257261
end
258262

@@ -273,9 +277,10 @@ function TimesOperator(ops::AbstractVector{O},
273277
sbbw::Tuple{Any,Any}=bandwidthssum(ops, subblockbandwidths),
274278
ibbb::Bool=all(isbandedblockbanded, ops),
275279
irb::Bool=all(israggedbelow, ops),
280+
dsp = domainspace(last(ops)),
276281
) where {O<:Operator}
277282
TimesOperator{eltype(O),typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw)}(
278-
convert_vector(ops), bw, sz, bbw, sbbw, ibbb, irb)
283+
convert_vector(ops), bw, sz, bbw, sbbw, ibbb, irb, dsp)
279284
end
280285

281286
_extractops(A::TimesOperator, ::typeof(*)) = A.ops
@@ -284,9 +289,10 @@ function TimesOperator(A::Operator, B::Operator)
284289
v = collateops(*, A, B)
285290
ibbb = all(isbandedblockbanded, (A, B))
286291
irb = all(israggedbelow, (A, B))
292+
dsp = domainspace(B)
287293
TimesOperator(convert_vector(v), _bandwidthssum(A, B), _timessize((A, B)),
288294
_bandwidthssum(A, B, blockbandwidths),
289-
_bandwidthssum(A, B, subblockbandwidths), ibbb, irb)
295+
_bandwidthssum(A, B, subblockbandwidths), ibbb, irb, dsp)
290296
end
291297

292298

@@ -301,7 +307,7 @@ function convert(::Type{Operator{T}}, P::TimesOperator) where {T}
301307
_convertops(Operator{T}, ops),
302308
bandwidths(P), size(P), blockbandwidths(P),
303309
subblockbandwidths(P), isbandedblockbanded(P),
304-
israggedbelow(P))::Operator{T}
310+
israggedbelow(P), domainspace(P))::Operator{T}
305311
end
306312
end
307313

@@ -318,7 +324,7 @@ end
318324

319325
@assert length(opsin) > 1 "need at least 2 operators"
320326
ops, bw, bbw, sbbw, ibbb, irb = __promotetimes(opsin, dsp, anytimesop)
321-
TimesOperator(ops, bw, sz, bbw, sbbw, ibbb, irb)
327+
TimesOperator(ops, bw, sz, bbw, sbbw, ibbb, irb, dsp)
322328
end
323329
function __promotetimes(opsin, dsp, anytimesop)
324330
ops = Vector{Operator{promote_eltypeof(opsin)}}(undef, 0)
@@ -372,7 +378,8 @@ end
372378
end
373379
end
374380

375-
domainspace(P::PlusOrTimesOp) = domainspace(last(P.ops))
381+
domainspace(P::PlusOperator) = domainspace(last(P.ops))
382+
domainspace(T::TimesOperator) = T.domainspace
376383
rangespace(P::PlusOrTimesOp) = rangespace(first(P.ops))
377384

378385
domain(P::TimesOperator) = commondomain(P.ops)

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,10 @@ end
240240
f = Fun(sp, coeff)
241241
for sp2 in Any[(), (sp,)]
242242
M = Multiplication(f, sp2...)
243-
a = (M * M) * M
244-
b = M * (M * M)
243+
a = TimesOperator(M, M) * M
244+
b = M * TimesOperator(M, M)
245245
@test a == b
246+
@test (@inferred domainspace(a)) == domainspace(M)
246247
@test bandwidths(a) == bandwidths(b)
247248
end
248249
@testset "unwrap TimesOperator" begin

0 commit comments

Comments
 (0)