diff --git a/src/basic.jl b/src/basic.jl index 05868889..bfb21644 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -243,9 +243,11 @@ end # constructors for T in SCALINGNUMBERTYPES[2:end] - @eval ScaledOperator(λ::$T, L::AbstractSciMLOperator) = ScaledOperator( - ScalarOperator(λ), - L) + @eval function ScaledOperator(λ::$T, L::AbstractSciMLOperator) + T2 = Base.promote_eltype(λ, L) + Λ = λ isa UniformScaling ? UniformScaling(T2(λ.λ)) : T2(λ) + ScaledOperator(ScalarOperator(Λ), L) + end end for T in SCALINGNUMBERTYPES @@ -276,18 +278,16 @@ for T in SCALINGNUMBERTYPES[2:end] isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule end - @eval function Base.:*(α::$T, L::MatrixOperator) - isconstant(L) && return MatrixOperator(α * L.A) - return ScaledOperator(α, L) # Going back to the generic case - end - @eval function Base.:*(L::MatrixOperator, α::$T) - isconstant(L) && return MatrixOperator(α * L.A) - return ScaledOperator(α, L) # Going back to the generic case - end end -Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L) Base.:+(L::AbstractSciMLOperator) = L +Base.:-(L::AbstractSciMLOperator{T}) where T = ScaledOperator(-one(T), L) + +# Special cases for constant scalars. These simplify the structure when applicable +function Base.:-(L::ScaledOperator) + isconstant(L.λ) && return ScaledOperator(-L.λ, L.L) + return ScaledOperator(L.λ, -L.L) # Try to propagate the rule +end function Base.convert(::Type{AbstractMatrix}, L::ScaledOperator) convert(Number, L.λ) * convert(AbstractMatrix, L.L) @@ -428,9 +428,11 @@ struct AddedOperator{T, function AddedOperator(ops) @assert !isempty(ops) - _check_AddedOperator_sizes(ops) - T = mapreduce(eltype, promote_type, ops) - new{T, typeof(ops)}(ops) + # Flatten nested AddedOperators + ops_flat = _flatten_added_operators(ops) + _check_AddedOperator_sizes(ops_flat) + T = mapreduce(eltype, promote_type, ops_flat) + new{T, typeof(ops_flat)}(ops_flat) end end @@ -440,6 +442,25 @@ end AddedOperator(L::AbstractSciMLOperator) = L +# Helper function to flatten nested AddedOperators +@generated function _flatten_added_operators(ops::Tuple) + exprs = () + for i in 1:length(ops.parameters) + T = ops.parameters[i] + if T <: AddedOperator + # If this element is an AddedOperator, unpack its ops + exprs = (exprs..., :(ops[$i].ops...)) + else + # Otherwise, keep the element as-is + exprs = (exprs..., :(ops[$i])) + end + end + + return quote + tuple($(exprs...)) + end +end + @generated function _check_AddedOperator_sizes(ops::Tuple) ops_types = ops.parameters N = length(ops_types) diff --git a/test/basic.jl b/test/basic.jl index f6c9e314..f941ba38 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -141,6 +141,20 @@ end end end +@testset "Unary +/-" begin + A = MatrixOperator(rand(N, N)) + v = rand(N, K) + + # Test unary + + @test +A === A + + # Test unary - on constant MatrixOperator (simplified to MatrixOperator) + minusA = -A + @test minusA isa ScaledOperator + @test minusA * v ≈ -A.A * v + @test eltype(minusA.λ) == eltype(A.A) +end + @testset "ScaledOperator" begin A = rand(N, N) D = Diagonal(rand(N)) @@ -269,13 +283,24 @@ end w_orig = copy(v) @test mul!(v, op, u, α, β) ≈ α * (A + B) * u + β * w_orig - # ensure AddedOperator doesn't nest + # Test flattening of nested AddedOperators via direct constructor A = MatrixOperator(rand(N, N)) - L = A + (A + A) + A + B = MatrixOperator(rand(N, N)) + C = MatrixOperator(rand(N, N)) + + # Create nested structure: (A + B) is an AddedOperator + AB = A + B + @test AB isa AddedOperator + + # When we create AddedOperator((AB, C)), it should flatten + L = AddedOperator((AB, C)) @test L isa AddedOperator - for op in L.ops - @test !isa(op, AddedOperator) - end + @test length(L.ops) == 3 # Should have A, B, C (not AB and C) + @test all(op -> !isa(op, AddedOperator), L.ops) + + # Verify correctness + test_vec = rand(N, K) + @test L * test_vec ≈ (A + B + C) * test_vec ## Time-Dependent Coefficients for T in (Float32, Float64, ComplexF32, ComplexF64)