Skip to content

Commit 285de13

Browse files
Fix type instabilities on complex AddedOperators and add new methods
1 parent 4937fe4 commit 285de13

File tree

2 files changed

+23
-26
lines changed

2 files changed

+23
-26
lines changed

src/basic.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ struct AddedOperator{T,
429429
function AddedOperator(ops)
430430
@assert !isempty(ops)
431431
_check_AddedOperator_sizes(ops)
432-
T = promote_type(eltype.(ops)...)
432+
T = mapreduce(eltype, promote_type, ops)
433433
new{T, typeof(ops)}(ops)
434434
end
435435
end
@@ -476,9 +476,13 @@ function Base.:+(Z::NullOperator, A::AddedOperator)
476476
A
477477
end
478478

479+
Base.:-(A::AddedOperator) = AddedOperator(map(-, A.ops))
479480
Base.:-(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = AddedOperator(A, -B)
480481
Base.:-(A::AbstractSciMLOperator, B::AbstractMatrix) = A - MatrixOperator(B)
481482
Base.:-(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) - B
483+
Base.:-(A::AddedOperator, B::AbstractSciMLOperator) = AddedOperator(A.ops..., -B)
484+
Base.:-(A::AbstractSciMLOperator, B::AddedOperator) = AddedOperator(A, (-B).ops...)
485+
Base.:-(A::AddedOperator, B::AddedOperator) = AddedOperator(A.ops..., (-B).ops...)
482486

483487
for op in (:+, :-)
484488
for T in SCALINGNUMBERTYPES

test/basic.jl

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -278,33 +278,26 @@ end
278278
end
279279

280280
## Time-Dependent Coefficients
281-
282281
for T in (Float32, Float64, ComplexF32, ComplexF64)
283282
N = 100
284-
A1_sparse = MatrixOperator(sprand(T, N, N, 5 / N))
285-
A2_sparse = MatrixOperator(sprand(T, N, N, 5 / N))
286-
A3_sparse = MatrixOperator(sprand(T, N, N, 5 / N))
287-
288-
A1_dense = MatrixOperator(rand(T, N, N))
289-
A2_dense = MatrixOperator(rand(T, N, N))
290-
A3_dense = MatrixOperator(rand(T, N, N))
291-
292-
coeff1(a, u, p, t) = sin(p.ω * t)
293-
coeff2(a, u, p, t) = cos(p.ω * t)
294-
coeff3(a, u, p, t) = sin(p.ω * t) * cos(p.ω * t)
295-
296-
c1 = ScalarOperator(rand(T), coeff1)
297-
c2 = ScalarOperator(rand(T), coeff2)
298-
c3 = ScalarOperator(rand(T), coeff3)
299-
300-
H_sparse = c1 * A1_sparse + c2 * A2_sparse + c3 * A3_sparse
301-
H_dense = c1 * A1_dense + c2 * A2_dense + c3 * A3_dense
302-
303-
u = rand(T, N)
304-
v = rand(T, N)
305-
du = similar(u)
306-
p == 0.1,)
307-
t = 0.1
283+
A = sprand(T, N, N, 2 / N)
284+
285+
func1(a, u, p, t) = t
286+
func2(a, u, p, t) = t^2
287+
func3(a, u, p, t) = t^3
288+
func4(a, u, p, t) = t^4
289+
func5(a, u, p, t) = t^5
290+
291+
O1 = MatrixOperator(A) + ScalarOperator(0.0, func1) * MatrixOperator(A) + ScalarOperator(0.0, func2) * MatrixOperator(A)
292+
293+
O2 = MatrixOperator(A) + ScalarOperator(0.0, func3) * MatrixOperator(A) + ScalarOperator(0.0, func4) * MatrixOperator(A)
294+
295+
O3 = MatrixOperator(A) + ScalarOperator(0.0, func5) * MatrixOperator(A)
296+
297+
Op = -1im * (O1 - O2)
298+
299+
@test length(Op.ops) == length(O1.ops) + length(O2.ops)
300+
@inferred Op + O3
308301
end
309302
end
310303

0 commit comments

Comments
 (0)