|
243 | 243 |
|
244 | 244 | # constructors |
245 | 245 | for T in SCALINGNUMBERTYPES[2:end] |
246 | | - @eval ScaledOperator(λ::$T, L::AbstractSciMLOperator) = ScaledOperator( |
247 | | - ScalarOperator(λ), |
248 | | - L) |
| 246 | + @eval function ScaledOperator(λ::$T, L::AbstractSciMLOperator) |
| 247 | + T2 = Base.promote_eltype(λ, L) |
| 248 | + Λ = λ isa UniformScaling ? UniformScaling(T2(λ.λ)) : T2(λ) |
| 249 | + ScaledOperator(ScalarOperator(Λ), L) |
| 250 | + end |
249 | 251 | end |
250 | 252 |
|
251 | 253 | for T in SCALINGNUMBERTYPES |
@@ -276,18 +278,16 @@ for T in SCALINGNUMBERTYPES[2:end] |
276 | 278 | isconstant(L.λ) && return ScaledOperator(α * L.λ, L.L) |
277 | 279 | return ScaledOperator(L.λ, α * L.L) # Try to propagate the rule |
278 | 280 | end |
279 | | - @eval function Base.:*(α::$T, L::MatrixOperator) |
280 | | - isconstant(L) && return MatrixOperator(α * L.A) |
281 | | - return ScaledOperator(α, L) # Going back to the generic case |
282 | | - end |
283 | | - @eval function Base.:*(L::MatrixOperator, α::$T) |
284 | | - isconstant(L) && return MatrixOperator(α * L.A) |
285 | | - return ScaledOperator(α, L) # Going back to the generic case |
286 | | - end |
287 | 281 | end |
288 | 282 |
|
289 | | -Base.:-(L::AbstractSciMLOperator) = ScaledOperator(-true, L) |
290 | 283 | Base.:+(L::AbstractSciMLOperator) = L |
| 284 | +Base.:-(L::AbstractSciMLOperator{T}) where T = ScaledOperator(-one(T), L) |
| 285 | + |
| 286 | +# Special cases for constant scalars. These simplify the structure when applicable |
| 287 | +function Base.:-(L::ScaledOperator) |
| 288 | + isconstant(L.λ) && return ScaledOperator(-L.λ, L.L) |
| 289 | + return ScaledOperator(L.λ, -L.L) # Try to propagate the rule |
| 290 | +end |
291 | 291 |
|
292 | 292 | function Base.convert(::Type{AbstractMatrix}, L::ScaledOperator) |
293 | 293 | convert(Number, L.λ) * convert(AbstractMatrix, L.L) |
@@ -428,9 +428,11 @@ struct AddedOperator{T, |
428 | 428 |
|
429 | 429 | function AddedOperator(ops) |
430 | 430 | @assert !isempty(ops) |
431 | | - _check_AddedOperator_sizes(ops) |
432 | | - T = mapreduce(eltype, promote_type, ops) |
433 | | - new{T, typeof(ops)}(ops) |
| 431 | + # Flatten nested AddedOperators |
| 432 | + ops_flat = _flatten_added_operators(ops) |
| 433 | + _check_AddedOperator_sizes(ops_flat) |
| 434 | + T = mapreduce(eltype, promote_type, ops_flat) |
| 435 | + new{T, typeof(ops_flat)}(ops_flat) |
434 | 436 | end |
435 | 437 | end |
436 | 438 |
|
|
440 | 442 |
|
441 | 443 | AddedOperator(L::AbstractSciMLOperator) = L |
442 | 444 |
|
| 445 | +# Helper function to flatten nested AddedOperators |
| 446 | +@generated function _flatten_added_operators(ops::Tuple) |
| 447 | + exprs = () |
| 448 | + for i in 1:length(ops.parameters) |
| 449 | + T = ops.parameters[i] |
| 450 | + if T <: AddedOperator |
| 451 | + # If this element is an AddedOperator, unpack its ops |
| 452 | + exprs = (exprs..., :(ops[$i].ops...)) |
| 453 | + else |
| 454 | + # Otherwise, keep the element as-is |
| 455 | + exprs = (exprs..., :(ops[$i])) |
| 456 | + end |
| 457 | + end |
| 458 | + |
| 459 | + return quote |
| 460 | + tuple($(exprs...)) |
| 461 | + end |
| 462 | +end |
| 463 | + |
443 | 464 | @generated function _check_AddedOperator_sizes(ops::Tuple) |
444 | 465 | ops_types = ops.parameters |
445 | 466 | N = length(ops_types) |
|
0 commit comments