|
141 | 141 | end |
142 | 142 | end |
143 | 143 |
|
| 144 | +@testset "Unary +/-" begin |
| 145 | + A = MatrixOperator(rand(N, N)) |
| 146 | + v = rand(N, K) |
| 147 | + |
| 148 | + # Test unary + |
| 149 | + @test +A === A |
| 150 | + |
| 151 | + # Test unary - on constant MatrixOperator (simplified to MatrixOperator) |
| 152 | + minusA = -A |
| 153 | + @test minusA isa MatrixOperator |
| 154 | + @test minusA * v ≈ -A.A * v |
| 155 | + |
| 156 | + # Test unary - on non-constant MatrixOperator (falls back to ScaledOperator) |
| 157 | + func(A, u, p, t) = t |
| 158 | + timeDepA = MatrixOperator(A.A; update_func = func) |
| 159 | + minusTimeDepA = -timeDepA |
| 160 | + @test minusTimeDepA isa ScaledOperator # Can't simplify, uses generic fallback |
| 161 | + |
| 162 | + # Test unary - on non-constant ScaledOperator (propagates to inner operator) |
| 163 | + timeDepScaled = ScalarOperator(0.0, func) * A |
| 164 | + @test !isconstant(timeDepScaled) |
| 165 | + minusTimeDepScaled = -timeDepScaled |
| 166 | + @test minusTimeDepScaled.λ isa ScalarOperator && minusTimeDepScaled.L isa MatrixOperator # Propagates negation to inner MatrixOperator |
| 167 | + |
| 168 | + # Test double negation |
| 169 | + @test -(-A) isa MatrixOperator |
| 170 | + @test (-(-A)) * v ≈ A * v |
| 171 | +end |
| 172 | + |
144 | 173 | @testset "ScaledOperator" begin |
145 | 174 | A = rand(N, N) |
146 | 175 | D = Diagonal(rand(N)) |
@@ -269,13 +298,24 @@ end |
269 | 298 | w_orig = copy(v) |
270 | 299 | @test mul!(v, op, u, α, β) ≈ α * (A + B) * u + β * w_orig |
271 | 300 |
|
272 | | - # ensure AddedOperator doesn't nest |
| 301 | + # Test flattening of nested AddedOperators via direct constructor |
273 | 302 | A = MatrixOperator(rand(N, N)) |
274 | | - L = A + (A + A) + A |
| 303 | + B = MatrixOperator(rand(N, N)) |
| 304 | + C = MatrixOperator(rand(N, N)) |
| 305 | + |
| 306 | + # Create nested structure: (A + B) is an AddedOperator |
| 307 | + AB = A + B |
| 308 | + @test AB isa AddedOperator |
| 309 | + |
| 310 | + # When we create AddedOperator((AB, C)), it should flatten |
| 311 | + L = AddedOperator((AB, C)) |
275 | 312 | @test L isa AddedOperator |
276 | | - for op in L.ops |
277 | | - @test !isa(op, AddedOperator) |
278 | | - end |
| 313 | + @test length(L.ops) == 3 # Should have A, B, C (not AB and C) |
| 314 | + @test all(op -> !isa(op, AddedOperator), L.ops) |
| 315 | + |
| 316 | + # Verify correctness |
| 317 | + test_vec = rand(N, K) |
| 318 | + @test L * test_vec ≈ (A + B + C) * test_vec |
279 | 319 |
|
280 | 320 | ## Time-Dependent Coefficients |
281 | 321 | for T in (Float32, Float64, ComplexF32, ComplexF64) |
|
0 commit comments