Skip to content

Commit 7180a9e

Browse files
Merge pull request #249 from albertomercurio/patch-1
Improve `mul!`, `AddedOperator`, and `update_coefficients!` to remove memory allocations
2 parents abaf736 + bb2c1d3 commit 7180a9e

File tree

8 files changed

+173
-45
lines changed

8 files changed

+173
-45
lines changed
Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
module SciMLOperatorsStaticArraysCoreExt
2-
3-
import SciMLOperators
4-
import StaticArraysCore
5-
6-
function Base.copyto!(L::SciMLOperators.MatrixOperator,
7-
rhs::Base.Broadcast.Broadcasted{<:StaticArraysCore.StaticArrayStyle})
8-
(copyto!(L.A, rhs); L)
9-
end
10-
11-
end #module
1+
module SciMLOperatorsStaticArraysCoreExt
2+
3+
import SciMLOperators
4+
import StaticArraysCore
5+
6+
function Base.copyto!(L::SciMLOperators.MatrixOperator,
7+
rhs::Base.Broadcast.Broadcasted{<:StaticArraysCore.StaticArrayStyle})
8+
(copyto!(L.A, rhs); L)
9+
end
10+
11+
end #module

src/basic.jl

Lines changed: 80 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ end
201201

202202
for T in SCALINGNUMBERTYPES
203203
@eval function ScaledOperator(λ::$T, L::ScaledOperator)
204-
λ = ScalarOperator(λ) * L.λ
204+
λ = λ * L.λ
205205
ScaledOperator(λ, L.L)
206206
end
207207

@@ -250,7 +250,7 @@ function update_coefficients!(L::ScaledOperator, u, p, t)
250250
update_coefficients!(L.L, u, p, t)
251251
update_coefficients!(L.λ, u, p, t)
252252

253-
L
253+
nothing
254254
end
255255

256256
getops(L::ScaledOperator) = (L.λ, L.L)
@@ -288,13 +288,14 @@ end
288288
Base.:*(L::ScaledOperator, u::AbstractVecOrMat) = L.λ * (L.L * u)
289289
Base.:\(L::ScaledOperator, u::AbstractVecOrMat) = L.λ \ (L.L \ u)
290290

291-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat)
291+
@inline function LinearAlgebra.mul!(
292+
v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat)
292293
iszero(L.λ) && return lmul!(false, v)
293294
a = convert(Number, L.λ)
294295
mul!(v, L.L, u, a, false)
295296
end
296297

297-
function LinearAlgebra.mul!(v::AbstractVecOrMat,
298+
@inline function LinearAlgebra.mul!(v::AbstractVecOrMat,
298299
L::ScaledOperator,
299300
u::AbstractVecOrMat,
300301
α,
@@ -326,22 +327,34 @@ struct AddedOperator{T,
326327

327328
function AddedOperator(ops)
328329
@assert !isempty(ops)
330+
_check_AddedOperator_sizes(ops)
329331
T = promote_type(eltype.(ops)...)
330332
new{T, typeof(ops)}(ops)
331333
end
332334
end
333335

334336
function AddedOperator(ops::AbstractSciMLOperator...)
335-
sz = size(first(ops))
336-
for op in ops[2:end]
337-
@assert size(op)==sz "Dimension mismatch: cannot add operators of
338-
sizes $(sz), and $(size(op))."
339-
end
340337
AddedOperator(ops)
341338
end
342339

343340
AddedOperator(L::AbstractSciMLOperator) = L
344341

342+
@generated function _check_AddedOperator_sizes(ops::Tuple)
343+
ops_types = ops.parameters
344+
N = length(ops_types)
345+
sz_expr_list = ()
346+
sz_expr = :(sz = size(first(ops)))
347+
for i in 2:N
348+
sz_expr_list = (sz_expr_list..., :(size(ops[$i]) == sz))
349+
end
350+
351+
quote
352+
$sz_expr
353+
@assert all(tuple($(sz_expr_list...))) "Dimension mismatch: cannot add operators of different sizes."
354+
nothing
355+
end
356+
end
357+
345358
# constructors
346359
Base.:+(A::AbstractSciMLOperator, B::AbstractMatrix) = A + MatrixOperator(B)
347360
Base.:+(A::AbstractMatrix, B::AbstractSciMLOperator) = MatrixOperator(A) + B
@@ -371,13 +384,15 @@ for op in (:+, :-)
371384
for LT in SCALINGCOMBINETYPES
372385
@eval function Base.$op(L::$LT, λ::$T)
373386
@assert issquare(L)
387+
iszero(λ) && return L
374388
N = size(L, 1)
375389
Id = IdentityOperator(N)
376390
AddedOperator(L, $op(λ) * Id)
377391
end
378392

379393
@eval function Base.$op(λ::$T, L::$LT)
380394
@assert issquare(L)
395+
iszero(λ) && return $op(L)
381396
N = size(L, 1)
382397
Id = IdentityOperator(N)
383398
AddedOperator(λ * Id, $op(L))
@@ -386,6 +401,23 @@ for op in (:+, :-)
386401
end
387402
end
388403

404+
for T in SCALINGNUMBERTYPES[2:end]
405+
@eval function Base.:*(λ::$T, L::AddedOperator)
406+
ops = map(op -> λ * op, L.ops)
407+
AddedOperator(ops)
408+
end
409+
410+
@eval function Base.:*(L::AddedOperator, λ::$T)
411+
ops = map(op -> λ * op, L.ops)
412+
AddedOperator(ops)
413+
end
414+
415+
@eval function Base.:/(L::AddedOperator, λ::$T)
416+
ops = map(op -> op / λ, L.ops)
417+
AddedOperator(ops)
418+
end
419+
end
420+
389421
function Base.convert(::Type{AbstractMatrix}, L::AddedOperator)
390422
sum(op -> convert(AbstractMatrix, op), L.ops)
391423
end
@@ -422,16 +454,32 @@ function update_coefficients(L::AddedOperator, u, p, t)
422454
@reset L.ops = ops
423455
end
424456

457+
@generated function update_coefficients!(L::AddedOperator, u, p, t)
458+
ops_types = L.parameters[2].parameters
459+
N = length(ops_types)
460+
quote
461+
Base.@nexprs $N i->begin
462+
update_coefficients!(L.ops[i], u, p, t)
463+
end
464+
465+
nothing
466+
end
467+
end
468+
425469
getops(L::AddedOperator) = L.ops
426470
islinear(L::AddedOperator) = all(islinear, getops(L))
427471
Base.iszero(L::AddedOperator) = all(iszero, getops(L))
428472
has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops)
429473

430-
function cache_internals(L::AddedOperator, u::AbstractVecOrMat)
431-
for i in 1:length(L.ops)
432-
@reset L.ops[i] = cache_operator(L.ops[i], u)
474+
@generated function cache_internals(L::AddedOperator, u::AbstractVecOrMat)
475+
ops_types = L.parameters[2].parameters
476+
N = length(ops_types)
477+
quote
478+
Base.@nexprs $N i->begin
479+
@reset L.ops[i] = cache_operator(L.ops[i], u)
480+
end
481+
L
433482
end
434-
L
435483
end
436484

437485
getindex(L::AddedOperator, i::Int) = sum(op -> op[i], L.ops)
@@ -441,26 +489,33 @@ function Base.:*(L::AddedOperator, u::AbstractVecOrMat)
441489
sum(op -> iszero(op) ? zero(u) : op * u, L.ops)
442490
end
443491

444-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat)
445-
mul!(v, first(L.ops), u)
446-
for op in L.ops[2:end]
447-
iszero(op) && continue
448-
mul!(v, op, u, true, true)
492+
@generated function LinearAlgebra.mul!(
493+
v::AbstractVecOrMat, L::AddedOperator, u::AbstractVecOrMat)
494+
ops_types = L.parameters[2].parameters
495+
N = length(ops_types)
496+
quote
497+
mul!(v, L.ops[1], u)
498+
Base.@nexprs $(N - 1) i->begin
499+
mul!(v, L.ops[i + 1], u, true, true)
500+
end
501+
v
449502
end
450-
v
451503
end
452504

453-
function LinearAlgebra.mul!(v::AbstractVecOrMat,
505+
@generated function LinearAlgebra.mul!(v::AbstractVecOrMat,
454506
L::AddedOperator,
455507
u::AbstractVecOrMat,
456508
α,
457509
β)
458-
lmul!(β, v)
459-
for op in L.ops
460-
iszero(op) && continue
461-
mul!(v, op, u, α, true)
510+
ops_types = L.parameters[2].parameters
511+
N = length(ops_types)
512+
quote
513+
lmul!(β, v)
514+
Base.@nexprs $(N) i->begin
515+
mul!(v, L.ops[i], u, α, true)
516+
end
517+
v
462518
end
463-
v
464519
end
465520

466521
"""

src/batch.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ end
8888

8989
function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
9090
L.update_func!(L.diag, u, p, t; kwargs...)
91+
92+
nothing
9193
end
9294

9395
getops(L::BatchedDiagonalOperator) = (L.diag,)

src/func.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
382382
update_coefficients!(op, u, p, t; filtered_kwargs...)
383383
end
384384

385-
L
385+
nothing
386386
end
387387

388388
function iscached(L::FunctionOperator)

src/interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,11 @@ L * u
9898
"""
9999
update_coefficients!(L, u, p, t; kwargs...) = nothing
100100

101+
# We cannot use @generated because we don't know the type structure of L in advance
101102
function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...)
102-
for op in getops(L)
103-
update_coefficients!(op, u, p, t; kwargs...)
104-
end
105-
L
103+
foreach(op -> update_coefficients!(op, u, p, t; kwargs...), getops(L))
104+
105+
nothing
106106
end
107107

108108
###

src/matrix.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ end
161161

162162
function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...)
163163
L.update_func!(L.A, u, p, t; kwargs...)
164+
165+
nothing
164166
end
165167

166168
# TODO - add tests for MatrixOperator indexing
@@ -194,10 +196,11 @@ end
194196
# operator application
195197
Base.:*(L::MatrixOperator, u::AbstractVecOrMat) = L.A * u
196198
Base.:\(L::MatrixOperator, u::AbstractVecOrMat) = L.A \ u
197-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat)
199+
@inline function LinearAlgebra.mul!(
200+
v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat)
198201
mul!(v, L.A, u)
199202
end
200-
function LinearAlgebra.mul!(v::AbstractVecOrMat,
203+
@inline function LinearAlgebra.mul!(v::AbstractVecOrMat,
201204
L::MatrixOperator,
202205
u::AbstractVecOrMat,
203206
α,

src/scalar.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,12 @@ has_ldiv!(α::ScalarOperator) = has_ldiv(α)
191191

192192
function update_coefficients!(L::ScalarOperator, u, p, t; kwargs...)
193193
L.val = L.update_func(L.val, u, p, t; kwargs...)
194+
nothing
194195
end
195196

196197
function update_coefficients(L::ScalarOperator, u, p, t; kwargs...)
197-
@reset L.val = L.update_func(L.val, u, p, t; kwargs...)
198+
update_coefficients!(L, u, p, t; kwargs...)
199+
L
198200
end
199201

200202
"""
@@ -313,6 +315,26 @@ for op in (:*, :∘)
313315
end
314316
end
315317

318+
# Different methods for constant ScalarOperators
319+
for T in SCALINGNUMBERTYPES[2:end]
320+
@eval function Base.:*(α::ScalarOperator, x::$T)
321+
if isconstant(α)
322+
T2 = promote_type($T, eltype(α))
323+
return ScalarOperator(convert(T2, α) * x, α.update_func)
324+
else
325+
return ComposedScalarOperator(α, ScalarOperator(x))
326+
end
327+
end
328+
@eval function Base.:*(x::$T, α::ScalarOperator)
329+
if isconstant(α)
330+
T2 = promote_type($T, eltype(α))
331+
return ScalarOperator(convert(T2, α) * x, α.update_func)
332+
else
333+
return ComposedScalarOperator(ScalarOperator(x), α)
334+
end
335+
end
336+
end
337+
316338
function Base.convert(T::Type{<:Number}, α::ComposedScalarOperator)
317339
iszero(α) && return zero(T)
318340
prod(convert.(T, α.ops))

test/basic.jl

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SciMLOperators, LinearAlgebra
1+
using SciMLOperators, LinearAlgebra, SparseArrays
22
using Random
33

44
using SciMLOperators: IdentityOperator,
@@ -138,6 +138,13 @@ end
138138
@test ldiv!(op, u) ≈ (α * D) \ v
139139
end
140140

141+
function apply_op!(H, du, u, p, t)
142+
H(du, u, p, t)
143+
return nothing
144+
end
145+
146+
test_apply_noalloc(H, du, u, p, t) = @test (@allocations apply_op!(H, du, u, p, t)) == 0
147+
141148
@testset "AddedOperator" begin
142149
A = rand(N, N) |> MatrixOperator
143150
B = rand(N, N) |> MatrixOperator
@@ -184,6 +191,45 @@ end
184191
for op in L.ops
185192
@test !isa(op, AddedOperator)
186193
end
194+
195+
# Allocations Tests
196+
197+
@allocations apply_op!(op, v, u, (), 1.0) # warmup
198+
test_apply_noalloc(op, v, u, (), 1.0)
199+
200+
## Time-Dependent Coefficients
201+
202+
for T in (Float32, Float64, ComplexF32, ComplexF64)
203+
N = 100
204+
A1_sparse = MatrixOperator(sprand(T, N, N, 5 / N))
205+
A2_sparse = MatrixOperator(sprand(T, N, N, 5 / N))
206+
A3_sparse = MatrixOperator(sprand(T, N, N, 5 / N))
207+
208+
A1_dense = MatrixOperator(rand(T, N, N))
209+
A2_dense = MatrixOperator(rand(T, N, N))
210+
A3_dense = MatrixOperator(rand(T, N, N))
211+
212+
coeff1(a, u, p, t) = sin(p.ω * t)
213+
coeff2(a, u, p, t) = cos(p.ω * t)
214+
coeff3(a, u, p, t) = sin(p.ω * t) * cos(p.ω * t)
215+
216+
c1 = ScalarOperator(rand(T), coeff1)
217+
c2 = ScalarOperator(rand(T), coeff2)
218+
c3 = ScalarOperator(rand(T), coeff3)
219+
220+
H_sparse = c1 * A1_sparse + c2 * A2_sparse + c3 * A3_sparse
221+
H_dense = c1 * A1_dense + c2 * A2_dense + c3 * A3_dense
222+
223+
u = rand(T, N)
224+
du = similar(u)
225+
p = (ω = 0.1,)
226+
t = 0.1
227+
228+
@allocations apply_op!(H_sparse, du, u, p, t) # warmup
229+
@allocations apply_op!(H_dense, du, u, p, t) # warmup
230+
test_apply_noalloc(H_sparse, du, u, p, t)
231+
test_apply_noalloc(H_dense, du, u, p, t)
232+
end
187233
end
188234

189235
@testset "ComposedOperator" begin

0 commit comments

Comments
 (0)