Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,11 @@ end

# constructors
for T in SCALINGNUMBERTYPES[2:end]
@eval ScaledOperator(λ::$T, L::AbstractSciMLOperator) = ScaledOperator(
ScalarOperator(λ),
L)
@eval function ScaledOperator(λ::$T, L::AbstractSciMLOperator)
T2 = Base.promote_eltype(λ, L)
Λ = λ isa UniformScaling ? UniformScaling(T2(λ.λ)) : T2(λ)
ScaledOperator(ScalarOperator(Λ), L)
end
end

for T in SCALINGNUMBERTYPES
Expand Down Expand Up @@ -276,18 +278,16 @@ for T in SCALINGNUMBERTYPES[2:end]
isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L)
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
end
@eval function Base.:*(α::$T, L::MatrixOperator)
isconstant(L) && return MatrixOperator(α * L.A)
return ScaledOperator(α, L) # Going back to the generic case
end
@eval function Base.:*(L::MatrixOperator, α::$T)
isconstant(L) && return MatrixOperator(α * L.A)
return ScaledOperator(α, L) # Going back to the generic case
end
end

Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L)
Base.:+(L::AbstractSciMLOperator) = L
Base.:-(L::AbstractSciMLOperator{T}) where T = ScaledOperator(-one(T), L)

# Special cases for constant scalars. These simplify the structure when applicable
function Base.:-(L::ScaledOperator)
isconstant(L.λ) && return ScaledOperator(-L.λ, L.L)
return ScaledOperator(L.λ, -L.L) # Try to propagate the rule
end

function Base.convert(::Type{AbstractMatrix}, L::ScaledOperator)
convert(Number, L.λ) * convert(AbstractMatrix, L.L)
Expand Down Expand Up @@ -428,9 +428,11 @@ struct AddedOperator{T,

function AddedOperator(ops)
@assert !isempty(ops)
_check_AddedOperator_sizes(ops)
T = mapreduce(eltype, promote_type, ops)
new{T, typeof(ops)}(ops)
# Flatten nested AddedOperators
ops_flat = _flatten_added_operators(ops)
_check_AddedOperator_sizes(ops_flat)
T = mapreduce(eltype, promote_type, ops_flat)
new{T, typeof(ops_flat)}(ops_flat)
end
end

Expand All @@ -440,6 +442,25 @@ end

AddedOperator(L::AbstractSciMLOperator) = L

# Helper function to flatten nested AddedOperators
@generated function _flatten_added_operators(ops::Tuple)
exprs = ()
for i in 1:length(ops.parameters)
T = ops.parameters[i]
if T <: AddedOperator
# If this element is an AddedOperator, unpack its ops
exprs = (exprs..., :(ops[$i].ops...))
else
# Otherwise, keep the element as-is
exprs = (exprs..., :(ops[$i]))
end
end

return quote
tuple($(exprs...))
end
end

@generated function _check_AddedOperator_sizes(ops::Tuple)
ops_types = ops.parameters
N = length(ops_types)
Expand Down
35 changes: 30 additions & 5 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ end
end
end

@testset "Unary +/-" begin
A = MatrixOperator(rand(N, N))
v = rand(N, K)

# Test unary +
@test +A === A

# Test unary - on constant MatrixOperator (simplified to MatrixOperator)
minusA = -A
@test minusA isa ScaledOperator
@test minusA * v ≈ -A.A * v
@test eltype(minusA.λ) == eltype(A.A)
end

@testset "ScaledOperator" begin
A = rand(N, N)
D = Diagonal(rand(N))
Expand Down Expand Up @@ -269,13 +283,24 @@ end
w_orig = copy(v)
@test mul!(v, op, u, α, β) ≈ α * (A + B) * u + β * w_orig

# ensure AddedOperator doesn't nest
# Test flattening of nested AddedOperators via direct constructor
A = MatrixOperator(rand(N, N))
L = A + (A + A) + A
B = MatrixOperator(rand(N, N))
C = MatrixOperator(rand(N, N))

# Create nested structure: (A + B) is an AddedOperator
AB = A + B
@test AB isa AddedOperator

# When we create AddedOperator((AB, C)), it should flatten
L = AddedOperator((AB, C))
@test L isa AddedOperator
for op in L.ops
@test !isa(op, AddedOperator)
end
@test length(L.ops) == 3 # Should have A, B, C (not AB and C)
@test all(op -> !isa(op, AddedOperator), L.ops)

# Verify correctness
test_vec = rand(N, K)
@test L * test_vec ≈ (A + B + C) * test_vec

## Time-Dependent Coefficients
for T in (Float32, Float64, ComplexF32, ComplexF64)
Expand Down
Loading