@@ -221,7 +221,22 @@ axpy!(α, S::SubOperator{T,OP}, A::AbstractMatrix) where {T,OP<:ConstantTimesOpe
221
221
222
222
223
223
224
+ function check_times (ops)
225
+ for k = 1 : length (ops)- 1
226
+ size (ops[k], 2 ) == size (ops[k+ 1 ], 1 ) || throw (ArgumentError (" incompatible operator sizes" ))
227
+ spacescompatible (domainspace (ops[k]), rangespace (ops[k+ 1 ])) || throw (ArgumentError (" incompatible spaces at index $k " ))
228
+ end
229
+ return nothing
230
+ end
224
231
232
+ function splice_times (ops)
233
+ timesinds = findall (x -> isa (x, TimesOperator), ops)
234
+ newops = copy (ops)
235
+ for ind in timesinds
236
+ splice! (newops, ind, ops[ind]. ops)
237
+ end
238
+ newops
239
+ end
225
240
226
241
struct TimesOperator{T,BW,SZ,O<: Operator{T} ,BBW,SBBW} <: Operator{T}
227
242
ops:: Vector{O}
@@ -231,28 +246,36 @@ struct TimesOperator{T,BW,SZ,O<:Operator{T},BBW,SBBW} <: Operator{T}
231
246
subblockbandwidths:: SBBW
232
247
isbandedblockbanded:: Bool
233
248
israggedbelow:: Bool
249
+ isafunctional:: Bool
234
250
235
- function TimesOperator {T,BW,SZ,O,BBW,SBBW} (ops:: Vector{O} , bw:: BW ,
236
- sz:: SZ , bbw:: BBW , sbbw:: SBBW ,
237
- ibbb:: Bool , irb:: Bool ) where {T,O<: Operator{T} ,BW,SZ,BBW,SBBW}
238
- # check compatible
239
- for k = 1 : length (ops)- 1
240
- size (ops[k], 2 ) == size (ops[k+ 1 ], 1 ) || throw (ArgumentError (" incompatible operator sizes" ))
241
- spacescompatible (domainspace (ops[k]), rangespace (ops[k+ 1 ])) || throw (ArgumentError (" incompatible spaces at index $k " ))
242
- end
251
+ @static if VERSION >= v " 1.8"
252
+ Base. @constprop :aggressive function TimesOperator {T,BW,SZ,O,BBW,SBBW} (ops:: Vector{O} , bw:: BW ,
253
+ sz:: SZ , bbw:: BBW , sbbw:: SBBW ,
254
+ ibbb:: Bool , irb:: Bool , isaf:: Bool ;
255
+ anytimesop = any (x -> x isa TimesOperator, ops)) where {T,O<: Operator{T} ,BW,SZ,BBW,SBBW}
243
256
244
- # remove TimesOperators buried inside ops
245
- timesinds = findall (x -> isa (x, TimesOperator), ops)
246
- if ! isempty (timesinds)
247
- newops = copy (ops)
248
- for ind in timesinds
249
- splice! (newops, ind, ops[ind]. ops)
250
- end
251
- else
252
- newops = ops
257
+ # check compatible
258
+ check_times (ops)
259
+
260
+ # remove TimesOperators buried inside ops
261
+ newops = anytimesop ? splice_times (ops) : ops
262
+
263
+ new {T,BW,SZ,O,BBW,SBBW} (newops, bw, sz, bbw, sbbw, ibbb, irb, isaf)
253
264
end
265
+ else
266
+ function TimesOperator {T,BW,SZ,O,BBW,SBBW} (ops:: Vector{O} , bw:: BW ,
267
+ sz:: SZ , bbw:: BBW , sbbw:: SBBW ,
268
+ ibbb:: Bool , irb:: Bool , isaf:: Bool ;
269
+ anytimesop = any (x -> x isa TimesOperator, ops)) where {T,O<: Operator{T} ,BW,SZ,BBW,SBBW}
270
+
271
+ # check compatible
272
+ check_times (ops)
273
+
274
+ # remove TimesOperators buried inside ops
275
+ newops = anytimesop ? splice_times (ops) : ops
254
276
255
- new {T,BW,SZ,O,BBW,SBBW} (newops, bw, sz, bbw, sbbw, ibbb, irb)
277
+ new {T,BW,SZ,O,BBW,SBBW} (newops, bw, sz, bbw, sbbw, ibbb, irb, isaf)
278
+ end
256
279
end
257
280
end
258
281
@@ -273,9 +296,11 @@ function TimesOperator(ops::AbstractVector{O},
273
296
sbbw:: Tuple{Any,Any} = bandwidthssum (ops, subblockbandwidths),
274
297
ibbb:: Bool = all (isbandedblockbanded, ops),
275
298
irb:: Bool = all (israggedbelow, ops),
299
+ isaf:: Bool = sz[1 ] == 1 && isconstspace (rangespace (first (ops)));
300
+ anytimesop = any (x -> x isa TimesOperator, ops),
276
301
) where {O<: Operator }
277
302
TimesOperator {eltype(O),typeof(bw),typeof(sz),O,typeof(bbw),typeof(sbbw)} (
278
- convert_vector (ops), bw, sz, bbw, sbbw, ibbb, irb)
303
+ convert_vector (ops), bw, sz, bbw, sbbw, ibbb, irb, isaf; anytimesop )
279
304
end
280
305
281
306
_extractops (A:: TimesOperator , :: typeof (* )) = A. ops
@@ -284,9 +309,13 @@ function TimesOperator(A::Operator, B::Operator)
284
309
v = collateops (* , A, B)
285
310
ibbb = all (isbandedblockbanded, (A, B))
286
311
irb = all (israggedbelow, (A, B))
287
- TimesOperator (convert_vector (v), _bandwidthssum (A, B), _timessize ((A, B)),
312
+ sz = _timessize ((A, B))
313
+ isaf = sz[1 ] == 1 && isconstspace (rangespace (A))
314
+ anytimesop = any (x -> x isa TimesOperator, (A,B))
315
+ TimesOperator (convert_vector (v), _bandwidthssum (A, B), sz,
288
316
_bandwidthssum (A, B, blockbandwidths),
289
- _bandwidthssum (A, B, subblockbandwidths), ibbb, irb)
317
+ _bandwidthssum (A, B, subblockbandwidths), ibbb, irb, isaf;
318
+ anytimesop)
290
319
end
291
320
292
321
@@ -301,7 +330,7 @@ function convert(::Type{Operator{T}}, P::TimesOperator) where {T}
301
330
_convertops (Operator{T}, ops),
302
331
bandwidths (P), size (P), blockbandwidths (P),
303
332
subblockbandwidths (P), isbandedblockbanded (P),
304
- israggedbelow (P)):: Operator{T}
333
+ israggedbelow (P), P . isafunctional, anytimesop = false ):: Operator{T}
305
334
end
306
335
end
307
336
318
347
@assert length (opsin) > 1 " need at least 2 operators"
319
348
ops, bw, bbw, sbbw, ibbb, irb = __promotetimes (opsin, dsp, anytimesop)
320
349
sz = _timessize (ops)
321
- TimesOperator (convert_vector (ops), bw, sz, bbw, sbbw, ibbb, irb)
350
+ isaf = sz[1 ] == 1 && isconstspace (rangespace (first (ops)))
351
+ anytimesop = any (x -> x isa TimesOperator, ops)
352
+ TimesOperator (convert_vector (ops), bw, sz, bbw, sbbw, ibbb, irb, isaf; anytimesop)
322
353
end
323
354
function __promotetimes (opsin, dsp, anytimesop)
324
355
ops = Vector {Operator{promote_eltypeof(opsin)}} (undef, 0 )
@@ -388,6 +419,8 @@ isbandedblockbanded(P::PlusOrTimesOp) = P.isbandedblockbanded
388
419
389
420
israggedbelow (P:: PlusOrTimesOp ) = P. israggedbelow
390
421
422
+ isafunctional (T:: TimesOperator ) = T. isafunctional
423
+
391
424
Base. stride (P:: TimesOperator ) = mapreduce (stride, gcd, P. ops)
392
425
393
426
for OP in (:rowstart , :rowstop )
@@ -577,6 +610,7 @@ for OP in (:(adjoint), :(transpose))
577
610
reverse (blockbandwidths (A)),
578
611
reverse (subblockbandwidths (A)),
579
612
isbandedblockbanded (A),
613
+ anytimesop = false ,
580
614
)
581
615
end
582
616
0 commit comments