Skip to content

Commit d12111a

Browse files
authored
Fix bug in TimesOperator constructor, and add Multiplication equality (#110)
* performance optimizations in TimesOperator * add equality methods and tests
1 parent eeb33d4 commit d12111a

File tree

4 files changed

+47
-30
lines changed

4 files changed

+47
-30
lines changed

src/Operators/banded/Multiplication.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ function ConcreteMultiplication(f::Fun{D,T},sp::Space) where {D,T}
2626
ConcreteMultiplication{D,typeof(sp),V}(convert(Fun{D,V},chop(f,40*eps(cfstype(f)))),sp)
2727
end
2828

29+
==(A::ConcreteMultiplication, B::ConcreteMultiplication) = (A.f == B.f) && (A.space == B.space)
30+
2931
# We do this in two stages to support Modifier spaces
3032
# without ambiguity errors
3133
function defaultMultiplication(f::Fun,sp::Space)

src/Operators/general/FiniteOperator.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ FiniteOperator(M::AbstractMatrix{<:Number}) =
1919
convert(::Type{Operator{T}},F::FiniteOperator) where {T} =
2020
FiniteOperator(convert(AbstractMatrix{T},F.matrix),F.domainspace,F.rangespace)::Operator{T}
2121

22+
==(A::FiniteOperator, B::FiniteOperator) = A.matrix == B.matrix && A.domainspace == B.domainspace && A.rangespace == B.rangespace
2223

2324
Base.promote_rule(::Type{OT},::Type{MT}) where {OT<:Operator,MT<:AbstractMatrix} = Operator{promote_type(eltype(OT),eltype(MT))}
2425

src/Operators/general/algebra.jl

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -228,45 +228,27 @@ struct TimesOperator{T,BI} <: Operator{T}
228228
function TimesOperator{T,BI}(ops::Vector{Operator{T}},bi::BI) where {T,BI}
229229
# check compatible
230230
for k=1:length(ops)-1
231-
@assert size(ops[k],2) == size(ops[k+1],1)
231+
size(ops[k],2) == size(ops[k+1],1) || throw(ArgumentError("incompatible operator sizes"))
232+
spacescompatible(domainspace(ops[k]),rangespace(ops[k+1])) || throw(ArgumentError("imcompatible spaces at index $k"))
232233
end
233234

234235
# remove TimesOperators buried inside ops
235-
hastimes = false
236-
for k=1:length(ops)-1
237-
@assert spacescompatible(domainspace(ops[k]),rangespace(ops[k+1]))
238-
hastimes = hastimes || isa(ops[k],TimesOperator)
239-
end
240-
241-
if hastimes
242-
newops=Vector{Operator{T}}(0)
243-
for op in ops
244-
if isa(op,TimesOperator)
245-
for op2 in op.ops
246-
push!(newops,op2)
247-
end
248-
else
249-
push!(newops,op)
250-
end
236+
timesinds = findall(x -> isa(x, TimesOperator), ops)
237+
if !isempty(timesinds)
238+
newops = copy(ops)
239+
for ind in timesinds
240+
splice!(newops, ind, ops[ind].ops)
251241
end
252-
ops=newops
242+
else
243+
newops = ops
253244
end
254245

255-
256-
new{T,BI}(ops,bi)
246+
new{T,BI}(newops,bi)
257247
end
258248
end
259249

260-
261-
function bandwidthsum(P,k)
262-
ret=0
263-
for op in P
264-
ret+=bandwidths(op)[k]
265-
end
266-
ret
267-
end
268-
269-
bandwidthssum(P) = (bandwidthsum(P,1),bandwidthsum(P,2))
250+
bandwidthssum(P, k) = bandwidthssum(P)[k]
251+
bandwidthssum(P) = mapreduce(bandwidths, (t1, t2) -> t1 .+ t2, P, init = (0,0))
270252

271253
TimesOperator(ops::Vector{Operator{T}},bi::Tuple{N1,N2}) where {T,N1,N2} =
272254
TimesOperator{T,typeof(bi)}(ops,bi)

test/runtests.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,36 @@ end
8282
@test ApproxFunBase.coefficients(f) === v
8383
end
8484

85+
@testset "operator algebra" begin
86+
@testset "Multiplication" begin
87+
sp = PointSpace(1:3)
88+
coeff = [1:3;]
89+
f = Fun(sp, coeff)
90+
for sp2 in Any[(), (sp,)]
91+
a = Multiplication(f, sp2...)
92+
b = Multiplication(f, sp2...)
93+
@test a == b
94+
@test bandwidths(a) == bandwidths(b)
95+
end
96+
end
97+
@testset "TimesOperator" begin
98+
sp = PointSpace(1:3)
99+
coeff = [1:3;]
100+
f = Fun(sp, coeff)
101+
for sp2 in Any[(), (sp,)]
102+
M = Multiplication(f, sp2...)
103+
a = (M * M) * M
104+
b = M * (M * M)
105+
@test a == b
106+
@test bandwidths(a) == bandwidths(b)
107+
end
108+
@testset "unwrap TimesOperator" begin
109+
M = Multiplication(f)
110+
for ops in Any[Operator{Float64}[M, M * M], Operator{Float64}[M*M, M]]
111+
@test TimesOperator(ops).ops == [M, M, M]
112+
end
113+
end
114+
end
115+
end
116+
85117
@time include("ETDRK4Test.jl")

0 commit comments

Comments
 (0)