Skip to content

Commit 73f73ac

Browse files
Merge pull request #321 from albertomercurio/master
Add operator simplification for unary negation and flatten nested AddedOperators
2 parents 8181601 + 6454d33 commit 73f73ac

File tree

2 files changed

+66
-20
lines changed

2 files changed

+66
-20
lines changed

src/basic.jl

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -243,9 +243,11 @@ end
243243

244244
# constructors
245245
for T in SCALINGNUMBERTYPES[2:end]
246-
@eval ScaledOperator::$T, L::AbstractSciMLOperator) = ScaledOperator(
247-
ScalarOperator(λ),
248-
L)
246+
@eval function ScaledOperator::$T, L::AbstractSciMLOperator)
247+
T2 = Base.promote_eltype(λ, L)
248+
Λ = λ isa UniformScaling ? UniformScaling(T2.λ)) : T2(λ)
249+
ScaledOperator(ScalarOperator(Λ), L)
250+
end
249251
end
250252

251253
for T in SCALINGNUMBERTYPES
@@ -276,18 +278,16 @@ for T in SCALINGNUMBERTYPES[2:end]
276278
isconstant(L.λ) && return ScaledOperator* L.λ, L.L)
277279
return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule
278280
end
279-
@eval function Base.:*::$T, L::MatrixOperator)
280-
isconstant(L) && return MatrixOperator* L.A)
281-
return ScaledOperator(α, L) # Going back to the generic case
282-
end
283-
@eval function Base.:*(L::MatrixOperator, α::$T)
284-
isconstant(L) && return MatrixOperator* L.A)
285-
return ScaledOperator(α, L) # Going back to the generic case
286-
end
287281
end
288282

289-
Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L)
290283
Base.:+(L::AbstractSciMLOperator) = L
284+
Base.:-(L::AbstractSciMLOperator{T}) where T = ScaledOperator(-one(T), L)
285+
286+
# Special cases for constant scalars. These simplify the structure when applicable
287+
function Base.:-(L::ScaledOperator)
288+
isconstant(L.λ) && return ScaledOperator(-L.λ, L.L)
289+
return ScaledOperator(L.λ, -L.L) # Try to propagate the rule
290+
end
291291

292292
function Base.convert(::Type{AbstractMatrix}, L::ScaledOperator)
293293
convert(Number, L.λ) * convert(AbstractMatrix, L.L)
@@ -428,9 +428,11 @@ struct AddedOperator{T,
428428

429429
function AddedOperator(ops)
430430
@assert !isempty(ops)
431-
_check_AddedOperator_sizes(ops)
432-
T = mapreduce(eltype, promote_type, ops)
433-
new{T, typeof(ops)}(ops)
431+
# Flatten nested AddedOperators
432+
ops_flat = _flatten_added_operators(ops)
433+
_check_AddedOperator_sizes(ops_flat)
434+
T = mapreduce(eltype, promote_type, ops_flat)
435+
new{T, typeof(ops_flat)}(ops_flat)
434436
end
435437
end
436438

@@ -440,6 +442,25 @@ end
440442

441443
AddedOperator(L::AbstractSciMLOperator) = L
442444

445+
# Helper function to flatten nested AddedOperators
446+
@generated function _flatten_added_operators(ops::Tuple)
447+
exprs = ()
448+
for i in 1:length(ops.parameters)
449+
T = ops.parameters[i]
450+
if T <: AddedOperator
451+
# If this element is an AddedOperator, unpack its ops
452+
exprs = (exprs..., :(ops[$i].ops...))
453+
else
454+
# Otherwise, keep the element as-is
455+
exprs = (exprs..., :(ops[$i]))
456+
end
457+
end
458+
459+
return quote
460+
tuple($(exprs...))
461+
end
462+
end
463+
443464
@generated function _check_AddedOperator_sizes(ops::Tuple)
444465
ops_types = ops.parameters
445466
N = length(ops_types)

test/basic.jl

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ end
141141
end
142142
end
143143

144+
@testset "Unary +/-" begin
145+
A = MatrixOperator(rand(N, N))
146+
v = rand(N, K)
147+
148+
# Test unary +
149+
@test +A === A
150+
151+
# Test unary - on constant MatrixOperator (simplified to MatrixOperator)
152+
minusA = -A
153+
@test minusA isa ScaledOperator
154+
@test minusA * v -A.A * v
155+
@test eltype(minusA.λ) == eltype(A.A)
156+
end
157+
144158
@testset "ScaledOperator" begin
145159
A = rand(N, N)
146160
D = Diagonal(rand(N))
@@ -269,13 +283,24 @@ end
269283
w_orig = copy(v)
270284
@test mul!(v, op, u, α, β) α * (A + B) * u + β * w_orig
271285

272-
# ensure AddedOperator doesn't nest
286+
# Test flattening of nested AddedOperators via direct constructor
273287
A = MatrixOperator(rand(N, N))
274-
L = A + (A + A) + A
288+
B = MatrixOperator(rand(N, N))
289+
C = MatrixOperator(rand(N, N))
290+
291+
# Create nested structure: (A + B) is an AddedOperator
292+
AB = A + B
293+
@test AB isa AddedOperator
294+
295+
# When we create AddedOperator((AB, C)), it should flatten
296+
L = AddedOperator((AB, C))
275297
@test L isa AddedOperator
276-
for op in L.ops
277-
@test !isa(op, AddedOperator)
278-
end
298+
@test length(L.ops) == 3 # Should have A, B, C (not AB and C)
299+
@test all(op -> !isa(op, AddedOperator), L.ops)
300+
301+
# Verify correctness
302+
test_vec = rand(N, K)
303+
@test L * test_vec (A + B + C) * test_vec
279304

280305
## Time-Dependent Coefficients
281306
for T in (Float32, Float64, ComplexF32, ComplexF64)

0 commit comments

Comments
 (0)