Skip to content

Commit ad35494

Browse files
authored
store isafunctional in TimesOperator (#447)
1 parent 643492d commit ad35494

File tree

2 files changed

+58
-24
lines changed

2 files changed

+58
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.8.13"
3+
version = "0.8.14"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/Operators/general/algebra.jl

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,22 @@ axpy!(α, S::SubOperator{T,OP}, A::AbstractMatrix) where {T,OP<:ConstantTimesOpe
221221

222222

223223

224+
function check_times(ops)
225+
for k = 1:length(ops)-1
226+
size(ops[k], 2) == size(ops[k+1], 1) || throw(ArgumentError("incompatible operator sizes"))
227+
spacescompatible(domainspace(ops[k]), rangespace(ops[k+1])) || throw(ArgumentError("incompatible spaces at index $k"))
228+
end
229+
return nothing
230+
end
224231

232+
function splice_times(ops)
233+
timesinds = findall(x -> isa(x, TimesOperator), ops)
234+
newops = copy(ops)
235+
for ind in timesinds
236+
splice!(newops, ind, ops[ind].ops)
237+
end
238+
newops
239+
end
225240

226241
struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
227242
ops::Vector{O}
@@ -231,28 +246,36 @@ struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
231246
subblockbandwidths::SBBW
232247
isbandedblockbanded::Bool
233248
israggedbelow::Bool
249+
isafunctional::Bool
234250

235-
function TimesOperator{T,BW,SZ,O,BBW,SBBW}(ops::Vector{O}, bw::BW,
236-
sz::SZ, bbw::BBW, sbbw::SBBW,
237-
ibbb::Bool, irb::Bool) where {T,O<:Operator{T},BW,SZ,BBW,SBBW}
238-
# check compatible
239-
for k = 1:length(ops)-1
240-
size(ops[k], 2) == size(ops[k+1], 1) || throw(ArgumentError("incompatible operator sizes"))
241-
spacescompatible(domainspace(ops[k]), rangespace(ops[k+1])) || throw(ArgumentError("incompatible spaces at index $k"))
242-
end
251+
@static if VERSION >= v"1.8"
252+
Base.@constprop :aggressive function TimesOperator{T,BW,SZ,O,BBW,SBBW}(ops::Vector{O}, bw::BW,
253+
sz::SZ, bbw::BBW, sbbw::SBBW,
254+
ibbb::Bool, irb::Bool, isaf::Bool;
255+
anytimesop = any(x -> x isa TimesOperator, ops)) where {T,O<:Operator{T},BW,SZ,BBW,SBBW}
243256

244-
# remove TimesOperators buried inside ops
245-
timesinds = findall(x -> isa(x, TimesOperator), ops)
246-
if !isempty(timesinds)
247-
newops = copy(ops)
248-
for ind in timesinds
249-
splice!(newops, ind, ops[ind].ops)
250-
end
251-
else
252-
newops = ops
257+
# check compatible
258+
check_times(ops)
259+
260+
# remove TimesOperators buried inside ops
261+
newops = anytimesop ? splice_times(ops) : ops
262+
263+
new{T,BW,SZ,O,BBW,SBBW}(newops, bw, sz, bbw, sbbw, ibbb, irb, isaf)
253264
end
265+
else
266+
function TimesOperator{T,BW,SZ,O,BBW,SBBW}(ops::Vector{O}, bw::BW,
267+
sz::SZ, bbw::BBW, sbbw::SBBW,
268+
ibbb::Bool, irb::Bool, isaf::Bool;
269+
anytimesop = any(x -> x isa TimesOperator, ops)) where {T,O<:Operator{T},BW,SZ,BBW,SBBW}
270+
271+
# check compatible
272+
check_times(ops)
273+
274+
# remove TimesOperators buried inside ops
275+
newops = anytimesop ? splice_times(ops) : ops
254276

255-
new{T,BW,SZ,O,BBW,SBBW}(newops, bw, sz, bbw, sbbw, ibbb, irb)
277+
new{T,BW,SZ,O,BBW,SBBW}(newops, bw, sz, bbw, sbbw, ibbb, irb, isaf)
278+
end
256279
end
257280
end
258281

@@ -273,9 +296,11 @@ function TimesOperator(ops::AbstractVector{O},
273296
sbbw::Tuple{Any,Any}=bandwidthssum(ops, subblockbandwidths),
274297
ibbb::Bool=all(isbandedblockbanded, ops),
275298
irb::Bool=all(israggedbelow, ops),
299+
isaf::Bool = sz[1] == 1 && isconstspace(rangespace(first(ops)));
300+
anytimesop = any(x -> x isa TimesOperator, ops),
276301
) where {O<:Operator}
277302
TimesOperator{eltype(O),typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw)}(
278-
convert_vector(ops), bw, sz, bbw, sbbw, ibbb, irb)
303+
convert_vector(ops), bw, sz, bbw, sbbw, ibbb, irb, isaf; anytimesop)
279304
end
280305

281306
_extractops(A::TimesOperator, ::typeof(*)) = A.ops
@@ -284,9 +309,13 @@ function TimesOperator(A::Operator, B::Operator)
284309
v = collateops(*, A, B)
285310
ibbb = all(isbandedblockbanded, (A, B))
286311
irb = all(israggedbelow, (A, B))
287-
TimesOperator(convert_vector(v), _bandwidthssum(A, B), _timessize((A, B)),
312+
sz = _timessize((A, B))
313+
isaf = sz[1] == 1 && isconstspace(rangespace(A))
314+
anytimesop = any(x -> x isa TimesOperator, (A,B))
315+
TimesOperator(convert_vector(v), _bandwidthssum(A, B), sz,
288316
_bandwidthssum(A, B, blockbandwidths),
289-
_bandwidthssum(A, B, subblockbandwidths), ibbb, irb)
317+
_bandwidthssum(A, B, subblockbandwidths), ibbb, irb, isaf;
318+
anytimesop)
290319
end
291320

292321

@@ -301,7 +330,7 @@ function convert(::Type{Operator{T}}, P::TimesOperator) where {T}
301330
_convertops(Operator{T}, ops),
302331
bandwidths(P), size(P), blockbandwidths(P),
303332
subblockbandwidths(P), isbandedblockbanded(P),
304-
israggedbelow(P))::Operator{T}
333+
israggedbelow(P), P.isafunctional, anytimesop = false)::Operator{T}
305334
end
306335
end
307336

@@ -318,7 +347,9 @@ end
318347
@assert length(opsin) > 1 "need at least 2 operators"
319348
ops, bw, bbw, sbbw, ibbb, irb = __promotetimes(opsin, dsp, anytimesop)
320349
sz = _timessize(ops)
321-
TimesOperator(convert_vector(ops), bw, sz, bbw, sbbw, ibbb, irb)
350+
isaf = sz[1] == 1 && isconstspace(rangespace(first(ops)))
351+
anytimesop = any(x -> x isa TimesOperator, ops)
352+
TimesOperator(convert_vector(ops), bw, sz, bbw, sbbw, ibbb, irb, isaf; anytimesop)
322353
end
323354
function __promotetimes(opsin, dsp, anytimesop)
324355
ops = Vector{Operator{promote_eltypeof(opsin)}}(undef, 0)
@@ -388,6 +419,8 @@ isbandedblockbanded(P::PlusOrTimesOp) = P.isbandedblockbanded
388419

389420
israggedbelow(P::PlusOrTimesOp) = P.israggedbelow
390421

422+
isafunctional(T::TimesOperator) = T.isafunctional
423+
391424
Base.stride(P::TimesOperator) = mapreduce(stride, gcd, P.ops)
392425

393426
for OP in (:rowstart, :rowstop)
@@ -577,6 +610,7 @@ for OP in (:(adjoint), :(transpose))
577610
reverse(blockbandwidths(A)),
578611
reverse(subblockbandwidths(A)),
579612
isbandedblockbanded(A),
613+
anytimesop = false,
580614
)
581615
end
582616

0 commit comments

Comments
 (0)