Skip to content

Commit 67978db

Browse files
authored
Improve inference in PlusOperator (#269)
* Tuples in PlusOperator * Add inference tests * Generalize collateop * Version bump to v0.7.36
1 parent 17017a1 commit 67978db

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ApproxFunBase"
22
uuid = "fbd15aa5-315a-5a7d-a8a4-24992e37be05"
3-
version = "0.7.35"
3+
version = "0.7.36"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/Operators/general/algebra.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ function convert(::Type{Operator{T}}, P::PlusOperator) where T
4646
P
4747
else
4848
ops = P.ops
49-
PlusOperator(ops isa AbstractVector{<:Operator{T}} ? ops : map(x -> strictconvert(Operator{T}, x), ops),
49+
PlusOperator(ops isa AbstractVector{<:Operator{T}} ? ops :
50+
map(x -> strictconvert(Operator{T}, x), ops),
5051
P.bandwidths,P.sz)::Operator{T}
5152
end
5253
end
@@ -64,19 +65,19 @@ end
6465
domain(P::PlusOperator) = commondomain(P.ops)
6566

6667
_promote_eltypeof(As...) = _promote_eltypeof(As)
67-
_promote_eltypeof(As::Union{Vector, Tuple}) = mapreduce(eltype, promote_type, As)
68-
_promote_eltypeof(As::Vector{Operator{T}}) where {T} = T
68+
_promote_eltypeof(As::Union{AbstractVector, Tuple}) = mapreduce(eltype, promote_type, As)
69+
_promote_eltypeof(As::AbstractVector{<:Operator{T}}) where {T} = T
6970

70-
_extractops(A, ::Any) = [A]
71+
_extractops(A, ::Any) = SVector{1}(A)
7172
_extractops(A::PlusOperator, ::typeof(+)) = A.ops
7273

7374
function +(A::Operator,B::Operator)
74-
v = [_extractops(A, +); _extractops(B, +)]
75+
v = collateops(+, A, B)
7576
promoteplus(v, size(A))
7677
end
7778
# Optimization for 3-term sum
7879
function +(A::Operator,B::Operator,C::Operator)
79-
v = [_extractops(A,+); _extractops(B, +); _extractops(C, +)]
80+
v = collateops(+, A, B, C)
8081
promoteplus(v, size(A))
8182
end
8283

@@ -247,10 +248,10 @@ __bandwidthssum(A, B::NTuple{2,InfiniteCardinal{0}}) = B
247248
__bandwidthssum(A, B) = reduce((t1, t2) -> t1 .+ t2, (A, B), init = (0,0))
248249

249250
_timessize(ops) = (size(first(ops),1), size(last(ops),2))
250-
function TimesOperator(ops::Vector{O},
251+
function TimesOperator(ops::AbstractVector{O},
251252
bi::Tuple{Any,Any} = bandwidthssum(ops),
252253
sz::Tuple{Any,Any} = _timessize(ops)) where {T,O<:Operator{T}}
253-
TimesOperator{T,typeof(bi),typeof(sz),O}(ops,bi,sz)
254+
TimesOperator{T,typeof(bi),typeof(sz),O}(convert_vector(ops),bi,sz)
254255
end
255256

256257
_extractops(A::TimesOperator, ::typeof(*)) = A.ops
@@ -268,7 +269,8 @@ function convert(::Type{Operator{T}},P::TimesOperator) where T
268269
P
269270
else
270271
ops = P.ops
271-
TimesOperator(ops isa AbstractVector{<:Operator{T}} ? ops : map(x -> strictconvert(Operator{T}, x), ops) ,
272+
TimesOperator(ops isa AbstractVector{<:Operator{T}} ? ops :
273+
map(x -> strictconvert(Operator{T}, x), ops) ,
272274
bandwidths(P), size(P))
273275
end
274276
end
@@ -517,17 +519,22 @@ for OP in (:(adjoint),:(transpose))
517519
reverse(bandwidths(A)), reverse(size(A)))
518520
end
519521

520-
_collateops(A::TimesOperator, B::TimesOperator, ::typeof(*)) = [_extractops(A, *); _extractops(B, *)]
521-
_collateops(A::TimesOperator, B::Operator, ::typeof(*)) = [_extractops(A, *); _extractops(B, *)]
522-
_collateops(A::Operator, B::TimesOperator, ::typeof(*)) = [_extractops(A, *); _extractops(B, *)]
523-
_collateops(A::Operator, B::Operator, ::typeof(*)) = (A, B)
522+
const PlusOrTimesOp = Union{PlusOperator, TimesOperator}
523+
anyplustimes(op::Operator, ops...) = anyplustimes(ops...)
524+
anyplustimes(op::PlusOrTimesOp, ops...) = true
525+
anyplustimes() = false
526+
527+
collateops(op, As...) = collateops(op, Val(anyplustimes(As...)), As...)
528+
collateops(op, ::Val{true}, As...) = mapreduce(x -> _extractops(x, op), vcat, As)
529+
collateops(op, ::Val{false}, As...) = As
530+
524531
function *(A::Operator,B::Operator)
525532
if isconstop(A)
526533
promoterangespace(strictconvert(Number,A)*B,rangespace(A))
527534
elseif isconstop(B)
528535
promotedomainspace(strictconvert(Number,B)*A,domainspace(B))
529536
else
530-
promotetimes(_collateops(A, B, *),
537+
promotetimes(collateops(*, A, B),
531538
domainspace(B), _timessize((A,B)), false)
532539
end
533540
end

test/runtests.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ end
213213
f = Fun(PointSpace(1:3), c)
214214
M = Multiplication(f)
215215
@testset for t in [1, 3]
216-
op = M + t * M
216+
op = @inferred M + t * M
217217
@test bandwidths(op) == bandwidths(M)
218218
@test coefficients(op * f) == @. (1+t)*c^2
219219
for op2 in Any[M + M + t * M, op + M]
@@ -230,9 +230,11 @@ end
230230
@test coefficients(f1) == coefficients(f2) == coefficients(f3)
231231
end
232232
Z = ApproxFunBase.ZeroOperator()
233-
@test Z + Z == Z
234-
@test Z + Z + Z == Z
235-
@test Z + Z + Z + Z == Z
233+
@test (@inferred Z + Z) == Z
234+
@test (@inferred Z + Z + Z) == Z
235+
@test (@inferred Z + Z + Z + Z) == Z
236+
237+
@inferred (() -> (D = Derivative(); D + D))()
236238
end
237239
end
238240

0 commit comments

Comments
 (0)