@@ -46,7 +46,8 @@ function convert(::Type{Operator{T}}, P::PlusOperator) where T
46
46
P
47
47
else
48
48
ops = P. ops
49
- PlusOperator (ops isa AbstractVector{<: Operator{T} } ? ops : map (x -> strictconvert (Operator{T}, x), ops),
49
+ PlusOperator (ops isa AbstractVector{<: Operator{T} } ? ops :
50
+ map (x -> strictconvert (Operator{T}, x), ops),
50
51
P. bandwidths,P. sz):: Operator{T}
51
52
end
52
53
end
64
65
domain (P:: PlusOperator ) = commondomain (P. ops)
65
66
66
67
_promote_eltypeof (As... ) = _promote_eltypeof (As)
67
- _promote_eltypeof (As:: Union{Vector , Tuple} ) = mapreduce (eltype, promote_type, As)
68
- _promote_eltypeof (As:: Vector{ Operator{T}} ) where {T} = T
68
+ _promote_eltypeof (As:: Union{AbstractVector , Tuple} ) = mapreduce (eltype, promote_type, As)
69
+ _promote_eltypeof (As:: AbstractVector{<: Operator{T}} ) where {T} = T
69
70
70
- _extractops (A, :: Any ) = [A]
71
+ _extractops (A, :: Any ) = SVector {1} (A)
71
72
_extractops (A:: PlusOperator , :: typeof (+ )) = A. ops
72
73
73
74
function + (A:: Operator ,B:: Operator )
74
- v = [ _extractops (A, + ); _extractops (B, + )]
75
+ v = collateops ( + , A, B)
75
76
promoteplus (v, size (A))
76
77
end
77
78
# Optimization for 3-term sum
78
79
function + (A:: Operator ,B:: Operator ,C:: Operator )
79
- v = [ _extractops (A, + ); _extractops ( B, + ); _extractops (C, + )]
80
+ v = collateops ( + , A, B, C)
80
81
promoteplus (v, size (A))
81
82
end
82
83
@@ -247,10 +248,10 @@ __bandwidthssum(A, B::NTuple{2,InfiniteCardinal{0}}) = B
247
248
__bandwidthssum (A, B) = reduce ((t1, t2) -> t1 .+ t2, (A, B), init = (0 ,0 ))
248
249
249
250
_timessize (ops) = (size (first (ops),1 ), size (last (ops),2 ))
250
- function TimesOperator (ops:: Vector {O} ,
251
+ function TimesOperator (ops:: AbstractVector {O} ,
251
252
bi:: Tuple{Any,Any} = bandwidthssum (ops),
252
253
sz:: Tuple{Any,Any} = _timessize (ops)) where {T,O<: Operator{T} }
253
- TimesOperator {T,typeof(bi),typeof(sz),O} (ops,bi,sz)
254
+ TimesOperator {T,typeof(bi),typeof(sz),O} (convert_vector ( ops) ,bi,sz)
254
255
end
255
256
256
257
_extractops (A:: TimesOperator , :: typeof (* )) = A. ops
@@ -268,7 +269,8 @@ function convert(::Type{Operator{T}},P::TimesOperator) where T
268
269
P
269
270
else
270
271
ops = P. ops
271
- TimesOperator (ops isa AbstractVector{<: Operator{T} } ? ops : map (x -> strictconvert (Operator{T}, x), ops) ,
272
+ TimesOperator (ops isa AbstractVector{<: Operator{T} } ? ops :
273
+ map (x -> strictconvert (Operator{T}, x), ops) ,
272
274
bandwidths (P), size (P))
273
275
end
274
276
end
@@ -517,17 +519,22 @@ for OP in (:(adjoint),:(transpose))
517
519
reverse (bandwidths (A)), reverse (size (A)))
518
520
end
519
521
520
- _collateops (A:: TimesOperator , B:: TimesOperator , :: typeof (* )) = [_extractops (A, * ); _extractops (B, * )]
521
- _collateops (A:: TimesOperator , B:: Operator , :: typeof (* )) = [_extractops (A, * ); _extractops (B, * )]
522
- _collateops (A:: Operator , B:: TimesOperator , :: typeof (* )) = [_extractops (A, * ); _extractops (B, * )]
523
- _collateops (A:: Operator , B:: Operator , :: typeof (* )) = (A, B)
522
+ const PlusOrTimesOp = Union{PlusOperator, TimesOperator}
523
+ anyplustimes (op:: Operator , ops... ) = anyplustimes (ops... )
524
+ anyplustimes (op:: PlusOrTimesOp , ops... ) = true
525
+ anyplustimes () = false
526
+
527
+ collateops (op, As... ) = collateops (op, Val (anyplustimes (As... )), As... )
528
+ collateops (op, :: Val{true} , As... ) = mapreduce (x -> _extractops (x, op), vcat, As)
529
+ collateops (op, :: Val{false} , As... ) = As
530
+
524
531
function * (A:: Operator ,B:: Operator )
525
532
if isconstop (A)
526
533
promoterangespace (strictconvert (Number,A)* B,rangespace (A))
527
534
elseif isconstop (B)
528
535
promotedomainspace (strictconvert (Number,B)* A,domainspace (B))
529
536
else
530
- promotetimes (_collateops (A, B, * ),
537
+ promotetimes (collateops ( * , A, B ),
531
538
domainspace (B), _timessize ((A,B)), false )
532
539
end
533
540
end
0 commit comments