From 285de132a96bc151d941104b5b139a7b006f0d3b Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Tue, 9 Sep 2025 16:33:54 +0200 Subject: [PATCH] Fix type instabilities on complex AddedOperators and add new methods --- src/basic.jl | 6 +++++- test/basic.jl | 43 ++++++++++++++++++------------------------- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/basic.jl b/src/basic.jl index b6f6a47a..e0f5291b 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -429,7 +429,7 @@ struct AddedOperator{T, function AddedOperator(ops) @assert !isempty(ops) _check_AddedOperator_sizes(ops) - T = promote_type(eltype.(ops)...) + T = mapreduce(eltype, promote_type, ops) new{T, typeof(ops)}(ops) end end @@ -476,9 +476,13 @@ function Base.:+(Z::NullOperator, A::AddedOperator) A end +Base.:-(A::AddedOperator) = AddedOperator(map(-, A.ops)) Base.:-(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = AddedOperator(A, -B) Base.:-(A::AbstractSciMLOperator, B::AbstractMatrix) = A - MatrixOperator(B) Base.:-(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) - B +Base.:-(A::AddedOperator, B::AbstractSciMLOperator) = AddedOperator(A.ops..., -B) +Base.:-(A::AbstractSciMLOperator, B::AddedOperator) = AddedOperator(A, (-B).ops...) +Base.:-(A::AddedOperator, B::AddedOperator) = AddedOperator(A.ops..., (-B).ops...) for op in (:+, :-) for T in SCALINGNUMBERTYPES diff --git a/test/basic.jl b/test/basic.jl index cf76ddd3..f6c9e314 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -278,33 +278,26 @@ end end ## Time-Dependent Coefficients - for T in (Float32, Float64, ComplexF32, ComplexF64) N = 100 - A1_sparse = MatrixOperator(sprand(T, N, N, 5 / N)) - A2_sparse = MatrixOperator(sprand(T, N, N, 5 / N)) - A3_sparse = MatrixOperator(sprand(T, N, N, 5 / N)) - - A1_dense = MatrixOperator(rand(T, N, N)) - A2_dense = MatrixOperator(rand(T, N, N)) - A3_dense = MatrixOperator(rand(T, N, N)) - - coeff1(a, u, p, t) = sin(p.ω * t) - coeff2(a, u, p, t) = cos(p.ω * t) - coeff3(a, u, p, t) = sin(p.ω * t) * cos(p.ω * t) - - c1 = ScalarOperator(rand(T), coeff1) - c2 = ScalarOperator(rand(T), coeff2) - c3 = ScalarOperator(rand(T), coeff3) - - H_sparse = c1 * A1_sparse + c2 * A2_sparse + c3 * A3_sparse - H_dense = c1 * A1_dense + c2 * A2_dense + c3 * A3_dense - - u = rand(T, N) - v = rand(T, N) - du = similar(u) - p = (ω = 0.1,) - t = 0.1 + A = sprand(T, N, N, 2 / N) + + func1(a, u, p, t) = t + func2(a, u, p, t) = t^2 + func3(a, u, p, t) = t^3 + func4(a, u, p, t) = t^4 + func5(a, u, p, t) = t^5 + + O1 = MatrixOperator(A) + ScalarOperator(0.0, func1) * MatrixOperator(A) + ScalarOperator(0.0, func2) * MatrixOperator(A) + + O2 = MatrixOperator(A) + ScalarOperator(0.0, func3) * MatrixOperator(A) + ScalarOperator(0.0, func4) * MatrixOperator(A) + + O3 = MatrixOperator(A) + ScalarOperator(0.0, func5) * MatrixOperator(A) + + Op = -1im * (O1 - O2) + + @test length(Op.ops) == length(O1.ops) + length(O2.ops) + @inferred Op + O3 end end