Skip to content

Commit 934ae5d

Browse files
Implement generated functions
1 parent 3fb8a9e commit 934ae5d

File tree

6 files changed

+68
-33
lines changed

6 files changed

+68
-33
lines changed

src/basic.jl

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ function update_coefficients!(L::ScaledOperator, u, p, t)
250250
update_coefficients!(L.L, u, p, t)
251251
update_coefficients!(L.λ, u, p, t)
252252

253-
L
253+
nothing
254254
end
255255

256256
getops(L::ScaledOperator) = (L.λ, L.L)
@@ -327,22 +327,34 @@ struct AddedOperator{T,
327327

328328
function AddedOperator(ops)
329329
@assert !isempty(ops)
330+
_check_AddedOperator_sizes(ops)
330331
T = promote_type(eltype.(ops)...)
331332
new{T, typeof(ops)}(ops)
332333
end
333334
end
334335

335336
function AddedOperator(ops::AbstractSciMLOperator...)
336-
sz = size(first(ops))
337-
for op in ops[2:end]
338-
@assert size(op)==sz "Dimension mismatch: cannot add operators of
339-
sizes $(sz), and $(size(op))."
340-
end
341337
AddedOperator(ops)
342338
end
343339

344340
AddedOperator(L::AbstractSciMLOperator) = L
345341

342+
@generated function _check_AddedOperator_sizes(ops::Tuple)
343+
ops_types = ops.parameters
344+
N = length(ops_types)
345+
sz_expr_list = ()
346+
sz_expr = :(sz = size(first(ops)))
347+
for i in 2:N
348+
sz_expr_list = (sz_expr_list..., :(size(ops[$i]) == sz))
349+
end
350+
351+
quote
352+
$sz_expr
353+
@assert all(tuple($(sz_expr_list...))) "Dimension mismatch: cannot add operators of different sizes."
354+
nothing
355+
end
356+
end
357+
346358
# constructors
347359
Base.:+(A::AbstractSciMLOperator, B::AbstractMatrix) = A + MatrixOperator(B)
348360
Base.:+(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) + B
@@ -372,13 +384,15 @@ for op in (:+, :-)
372384
for LT in SCALINGCOMBINETYPES
373385
@eval function Base.$op(L::$LT, λ::$T)
374386
@assert issquare(L)
387+
iszero(λ) && return L
375388
N = size(L, 1)
376389
Id = IdentityOperator(N)
377390
AddedOperator(L, $op(λ) * Id)
378391
end
379392

380393
@eval function Base.$op::$T, L::$LT)
381394
@assert issquare(L)
395+
iszero(λ) && return $op(L)
382396
N = size(L, 1)
383397
Id = IdentityOperator(N)
384398
AddedOperator* Id, $op(L))
@@ -440,24 +454,32 @@ function update_coefficients(L::AddedOperator, u, p, t)
440454
@reset L.ops = ops
441455
end
442456

443-
function update_coefficients!(L::AddedOperator, u, p, t)
444-
for op in L.ops
445-
update_coefficients!(op, u, p, t)
446-
end
457+
@generated function update_coefficients!(L::AddedOperator, u, p, t)
458+
ops_types = L.parameters[2].parameters
459+
N = length(ops_types)
460+
quote
461+
Base.@nexprs $N i->begin
462+
update_coefficients!(L.ops[i], u, p, t)
463+
end
447464

448-
L
465+
nothing
466+
end
449467
end
450468

451469
getops(L::AddedOperator) = L.ops
452470
islinear(L::AddedOperator) = all(islinear, getops(L))
453471
Base.iszero(L::AddedOperator) = all(iszero, getops(L))
454472
has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops)
455473

456-
function cache_internals(L::AddedOperator, u::AbstractVecOrMat)
457-
for i in 1:length(L.ops)
458-
@reset L.ops[i] = cache_operator(L.ops[i], u)
474+
@generated function cache_internals(L::AddedOperator, u::AbstractVecOrMat)
475+
ops_types = L.parameters[2].parameters
476+
N = length(ops_types)
477+
quote
478+
Base.@nexprs $N i->begin
479+
@reset L.ops[i] = cache_operator(L.ops[i], u)
480+
end
481+
L
459482
end
460-
L
461483
end
462484

463485
getindex(L::AddedOperator, i::Int) = sum(op -> op[i], L.ops)
@@ -467,26 +489,33 @@ function Base.:*(L::AddedOperator, u::AbstractVecOrMat)
467489
sum(op -> iszero(op) ? zero(u) : op * u, L.ops)
468490
end
469491

470-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat)
471-
mul!(v, first(L.ops), u)
472-
for op in L.ops[2:end]
473-
iszero(op) && continue
474-
mul!(v, op, u, true, true)
492+
@generated function LinearAlgebra.mul!(
493+
v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat)
494+
ops_types = L.parameters[2].parameters
495+
N = length(ops_types)
496+
quote
497+
mul!(v, L.ops[1], u)
498+
Base.@nexprs $(N - 1) i->begin
499+
mul!(v, L.ops[i + 1], u, true, true)
500+
end
501+
v
475502
end
476-
v
477503
end
478504

479-
function LinearAlgebra.mul!(v::AbstractVecOrMat,
505+
@generated function LinearAlgebra.mul!(v::AbstractVecOrMat,
480506
L::AddedOperator,
481507
u::AbstractVecOrMat,
482508
α,
483509
β)
484-
lmul!(β, v)
485-
for op in L.ops
486-
iszero(op) && continue
487-
mul!(v, op, u, α, true)
510+
ops_types = L.parameters[2].parameters
511+
N = length(ops_types)
512+
quote
513+
lmul!(β, v)
514+
Base.@nexprs $(N) i->begin
515+
mul!(v, L.ops[i], u, α, true)
516+
end
517+
v
488518
end
489-
v
490519
end
491520

492521
"""

src/batch.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ end
8888

8989
function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
9090
L.update_func!(L.diag, u, p, t; kwargs...)
91+
92+
nothing
9193
end
9294

9395
getops(L::BatchedDiagonalOperator) = (L.diag,)

src/func.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
382382
update_coefficients!(op, u, p, t; filtered_kwargs...)
383383
end
384384

385-
L
385+
nothing
386386
end
387387

388388
function iscached(L::FunctionOperator)

src/interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ L * u
9898
"""
9999
update_coefficients!(L, u, p, t; kwargs...) = nothing
100100

101+
# We cannot use @generate because we don't know the type structure of L in advance
101102
function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...)
102-
for op in getops(L)
103-
update_coefficients!(op, u, p, t; kwargs...)
104-
end
105-
L
103+
foreach(op -> update_coefficients!(op, u, p, t; kwargs...), getops(L))
104+
105+
nothing
106106
end
107107

108108
###

src/matrix.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ end
161161

162162
function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...)
163163
L.update_func!(L.A, u, p, t; kwargs...)
164+
165+
nothing
164166
end
165167

166168
# TODO - add tests for MatrixOperator indexing

src/scalar.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,12 @@ has_ldiv!(α::ScalarOperator) = has_ldiv(α)
191191

192192
function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...)
193193
L.val = L.update_func(L.val, u, p, t; kwargs...)
194+
nothing
194195
end
195196

196197
function update_coefficients(L::ScalarOperator, u, p, t; kwargs...)
197-
@reset L.val = L.update_func(L.val, u, p, t; kwargs...)
198+
update_coefficients!(L, u, p, t; kwargs...)
199+
L
198200
end
199201

200202
"""

0 commit comments

Comments
 (0)