@@ -250,7 +250,7 @@ function update_coefficients!(L::ScaledOperator, u, p, t)
250250 update_coefficients! (L. L, u, p, t)
251251 update_coefficients! (L. λ, u, p, t)
252252
253- L
253+ nothing
254254end
255255
256256getops (L:: ScaledOperator ) = (L. λ, L. L)
@@ -327,22 +327,34 @@ struct AddedOperator{T,
327327
328328 function AddedOperator (ops)
329329 @assert ! isempty (ops)
330+ _check_AddedOperator_sizes (ops)
330331 T = promote_type (eltype .(ops)... )
331332 new {T, typeof(ops)} (ops)
332333 end
333334end
334335
335336function AddedOperator (ops:: AbstractSciMLOperator... )
336- sz = size (first (ops))
337- for op in ops[2 : end ]
338- @assert size (op)== sz " Dimension mismatch: cannot add operators of
339- sizes $(sz) , and $(size (op)) ."
340- end
341337 AddedOperator (ops)
342338end
343339
344340AddedOperator (L:: AbstractSciMLOperator ) = L
345341
342+ @generated function _check_AddedOperator_sizes (ops:: Tuple )
343+ ops_types = ops. parameters
344+ N = length (ops_types)
345+ sz_expr_list = ()
346+ sz_expr = :(sz = size (first (ops)))
347+ for i in 2 : N
348+ sz_expr_list = (sz_expr_list... , :(size (ops[$ i]) == sz))
349+ end
350+
351+ quote
352+ $ sz_expr
353+ @assert all (tuple ($ (sz_expr_list... ))) " Dimension mismatch: cannot add operators of different sizes."
354+ nothing
355+ end
356+ end
357+
346358# constructors
347359Base.:+ (A:: AbstractSciMLOperator , B:: AbstractMatrix ) = A + MatrixOperator (B)
348360Base.:+ (A:: AbstractMatrix , B:: AbstractSciMLOperator ) = MatrixOperator (A) + B
@@ -372,13 +384,15 @@ for op in (:+, :-)
372384 for LT in SCALINGCOMBINETYPES
373385 @eval function Base. $op (L:: $LT , λ:: $T )
374386 @assert issquare (L)
387+ iszero (λ) && return L
375388 N = size (L, 1 )
376389 Id = IdentityOperator (N)
377390 AddedOperator (L, $ op (λ) * Id)
378391 end
379392
380393 @eval function Base. $op (λ:: $T , L:: $LT )
381394 @assert issquare (L)
395+ iszero (λ) && return $ op (L)
382396 N = size (L, 1 )
383397 Id = IdentityOperator (N)
384398 AddedOperator (λ * Id, $ op (L))
@@ -440,24 +454,32 @@ function update_coefficients(L::AddedOperator, u, p, t)
440454 @reset L. ops = ops
441455end
442456
443- function update_coefficients! (L:: AddedOperator , u, p, t)
444- for op in L. ops
445- update_coefficients! (op, u, p, t)
446- end
457+ @generated function update_coefficients! (L:: AddedOperator , u, p, t)
458+ ops_types = L. parameters[2 ]. parameters
459+ N = length (ops_types)
460+ quote
461+ Base. @nexprs $ N i-> begin
462+ update_coefficients! (L. ops[i], u, p, t)
463+ end
447464
448- L
465+ nothing
466+ end
449467end
450468
451469getops (L:: AddedOperator ) = L. ops
452470islinear (L:: AddedOperator ) = all (islinear, getops (L))
453471Base. iszero (L:: AddedOperator ) = all (iszero, getops (L))
454472has_adjoint (L:: AddedOperator ) = all (has_adjoint, L. ops)
455473
456- function cache_internals (L:: AddedOperator , u:: AbstractVecOrMat )
457- for i in 1 : length (L. ops)
458- @reset L. ops[i] = cache_operator (L. ops[i], u)
474+ @generated function cache_internals (L:: AddedOperator , u:: AbstractVecOrMat )
475+ ops_types = L. parameters[2 ]. parameters
476+ N = length (ops_types)
477+ quote
478+ Base. @nexprs $ N i-> begin
479+ @reset L. ops[i] = cache_operator (L. ops[i], u)
480+ end
481+ L
459482 end
460- L
461483end
462484
463485getindex (L:: AddedOperator , i:: Int ) = sum (op -> op[i], L. ops)
@@ -467,26 +489,33 @@ function Base.:*(L::AddedOperator, u::AbstractVecOrMat)
467489 sum (op -> iszero (op) ? zero (u) : op * u, L. ops)
468490end
469491
470- function LinearAlgebra. mul! (v:: AbstractVecOrMat , L:: AddedOperator , u:: AbstractVecOrMat )
471- mul! (v, first (L. ops), u)
472- for op in L. ops[2 : end ]
473- iszero (op) && continue
474- mul! (v, op, u, true , true )
492+ @generated function LinearAlgebra. mul! (
493+ v:: AbstractVecOrMat , L:: AddedOperator , u:: AbstractVecOrMat )
494+ ops_types = L. parameters[2 ]. parameters
495+ N = length (ops_types)
496+ quote
497+ mul! (v, L. ops[1 ], u)
498+ Base. @nexprs $ (N - 1 ) i-> begin
499+ mul! (v, L. ops[i + 1 ], u, true , true )
500+ end
501+ v
475502 end
476- v
477503end
478504
479- function LinearAlgebra. mul! (v:: AbstractVecOrMat ,
505+ @generated function LinearAlgebra. mul! (v:: AbstractVecOrMat ,
480506 L:: AddedOperator ,
481507 u:: AbstractVecOrMat ,
482508 α,
483509 β)
484- lmul! (β, v)
485- for op in L. ops
486- iszero (op) && continue
487- mul! (v, op, u, α, true )
510+ ops_types = L. parameters[2 ]. parameters
511+ N = length (ops_types)
512+ quote
513+ lmul! (β, v)
514+ Base. @nexprs $ (N) i-> begin
515+ mul! (v, L. ops[i], u, α, true )
516+ end
517+ v
488518 end
489- v
490519end
491520
492521"""
0 commit comments