Skip to content

Commit 13014f9

Browse files
Add operator simplification for unary negation and flatten nested AddedOperators
1 parent 8181601 commit 13014f9

File tree

2 files changed

+80
-9
lines changed

2 files changed

+80
-9
lines changed

src/basic.jl

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,18 @@ for T in SCALINGNUMBERTYPES[2:end]
286286
end
287287
end
288288

289-
Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L)
290289
Base.:+(L::AbstractSciMLOperator) = L
290+
Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L)
291+
292+
# Special cases for constant scalars. These simplify the structure when applicable
293+
function Base.:-(L::ScaledOperator)
294+
isconstant(L.λ) && return ScaledOperator(-L.λ, L.L)
295+
return ScaledOperator(L.λ, -L.L) # Try to propagate the rule
296+
end
297+
function Base.:-(L::MatrixOperator)
298+
isconstant(L) && return MatrixOperator(-L.A)
299+
return ScaledOperator(-true, L) # Going back to the generic case
300+
end
291301

292302
function Base.convert(::Type{AbstractMatrix}, L::ScaledOperator)
293303
convert(Number, L.λ) * convert(AbstractMatrix, L.L)
@@ -428,9 +438,11 @@ struct AddedOperator{T,
428438

429439
function AddedOperator(ops)
430440
@assert !isempty(ops)
431-
_check_AddedOperator_sizes(ops)
432-
T = mapreduce(eltype, promote_type, ops)
433-
new{T, typeof(ops)}(ops)
441+
# Flatten nested AddedOperators
442+
ops_flat = _flatten_added_operators(ops)
443+
_check_AddedOperator_sizes(ops_flat)
444+
T = mapreduce(eltype, promote_type, ops_flat)
445+
new{T, typeof(ops_flat)}(ops_flat)
434446
end
435447
end
436448

@@ -440,6 +452,25 @@ end
440452

441453
AddedOperator(L::AbstractSciMLOperator) = L
442454

455+
# Helper function to flatten nested AddedOperators
456+
@generated function _flatten_added_operators(ops::Tuple)
457+
exprs = ()
458+
for i in 1:length(ops.parameters)
459+
T = ops.parameters[i]
460+
if T <: AddedOperator
461+
# If this element is an AddedOperator, unpack its ops
462+
exprs = (exprs..., :(ops[$i].ops...))
463+
else
464+
# Otherwise, keep the element as-is
465+
exprs = (exprs..., :(ops[$i]))
466+
end
467+
end
468+
469+
return quote
470+
tuple($(exprs...))
471+
end
472+
end
473+
443474
@generated function _check_AddedOperator_sizes(ops::Tuple)
444475
ops_types = ops.parameters
445476
N = length(ops_types)

test/basic.jl

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,35 @@ end
141141
end
142142
end
143143

144+
@testset "Unary +/-" begin
145+
A = MatrixOperator(rand(N, N))
146+
v = rand(N, K)
147+
148+
# Test unary +
149+
@test +A === A
150+
151+
# Test unary - on constant MatrixOperator (simplified to MatrixOperator)
152+
minusA = -A
153+
@test minusA isa MatrixOperator
154+
@test minusA * v -A.A * v
155+
156+
# Test unary - on non-constant MatrixOperator (falls back to ScaledOperator)
157+
func(A, u, p, t) = t
158+
timeDepA = MatrixOperator(A.A; update_func = func)
159+
minusTimeDepA = -timeDepA
160+
@test minusTimeDepA isa ScaledOperator # Can't simplify, uses generic fallback
161+
162+
# Test unary - on non-constant ScaledOperator (propagates to inner operator)
163+
timeDepScaled = ScalarOperator(0.0, func) * A
164+
@test !isconstant(timeDepScaled)
165+
minusTimeDepScaled = -timeDepScaled
166+
@test minusTimeDepScaled.λ isa ScalarOperator && minusTimeDepScaled.L isa MatrixOperator # Propagates negation to inner MatrixOperator
167+
168+
# Test double negation
169+
@test -(-A) isa MatrixOperator
170+
@test (-(-A)) * v A * v
171+
end
172+
144173
@testset "ScaledOperator" begin
145174
A = rand(N, N)
146175
D = Diagonal(rand(N))
@@ -269,13 +298,24 @@ end
269298
w_orig = copy(v)
270299
@test mul!(v, op, u, α, β) α * (A + B) * u + β * w_orig
271300

272-
# ensure AddedOperator doesn't nest
301+
# Test flattening of nested AddedOperators via direct constructor
273302
A = MatrixOperator(rand(N, N))
274-
L = A + (A + A) + A
303+
B = MatrixOperator(rand(N, N))
304+
C = MatrixOperator(rand(N, N))
305+
306+
# Create nested structure: (A + B) is an AddedOperator
307+
AB = A + B
308+
@test AB isa AddedOperator
309+
310+
# When we create AddedOperator((AB, C)), it should flatten
311+
L = AddedOperator((AB, C))
275312
@test L isa AddedOperator
276-
for op in L.ops
277-
@test !isa(op, AddedOperator)
278-
end
313+
@test length(L.ops) == 3 # Should have A, B, C (not AB and C)
314+
@test all(op -> !isa(op, AddedOperator), L.ops)
315+
316+
# Verify correctness
317+
test_vec = rand(N, K)
318+
@test L * test_vec (A + B + C) * test_vec
279319

280320
## Time-Dependent Coefficients
281321
for T in (Float32, Float64, ComplexF32, ComplexF64)

0 commit comments

Comments
 (0)