Skip to content

Commit 401c204

Browse files
Merge pull request #43 from vpuri3/cache
ensure component operators get cached
2 parents ae57bd5 + cc10403 commit 401c204

File tree

5 files changed

+120
-38
lines changed

5 files changed

+120
-38
lines changed

src/SciMLOperators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ import Lazy: @forward
1111
import Setfield: @set!
1212

1313
# overload
14-
import Base: +, -, *, /, \, , ==
15-
import Base: conj, one, iszero, inv, adjoint, transpose, size, convert, Matrix
14+
import Base: +, -, *, /, \, , ==, one, zero
15+
import Base: conj, iszero, inv, adjoint, transpose, size, convert, Matrix
1616
import LinearAlgebra: mul!, ldiv!, lmul!, rmul!, factorize, exp, Diagonal
1717
import SparseArrays: sparse
1818

src/basic.jl

Lines changed: 65 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ function Base.one(A::AbstractSciMLOperator)
1212
IdentityOperator{N}()
1313
end
1414

15-
# TODO - sparse diagonal
1615
Base.convert(::Type{AbstractMatrix}, ::IdentityOperator{N}) where{N} = Diagonal(ones(Bool, N))
1716

1817
# traits
@@ -187,8 +186,10 @@ function Base.adjoint(α::ScalarOperator) # TODO - test
187186
ScalarOperator(val; update_func=update_func)
188187
end
189188
Base.transpose::ScalarOperator) = α
189+
Base.one(::Type{AbstractSciMLOperator}) = ScalarOperator(true)
190+
Base.zero(::Type{AbstractSciMLOperator}) = ScalarOperator(false)
190191

191-
getops::ScalarOperator) =.val)
192+
getops::ScalarOperator) =.val,)
192193
islinear(L::ScalarOperator) = true
193194
issquare(L::ScalarOperator) = true
194195
isconstant::ScalarOperator) = α.update_func == DEFAULT_UPDATE_FUNC
@@ -249,9 +250,11 @@ struct ScaledOperator{T,
249250
L::LType
250251
cache::T
251252

252-
function ScaledOperator::ScalarOperator, L::AbstractSciMLOperator)
253-
T = promote_type(eltype.((λ, L))...)
254-
cache = zero(T)
253+
function ScaledOperator::ScalarOperator{Tλ},
254+
L::AbstractSciMLOperator{TL},
255+
cache = zero(promote_type(Tλ,TL))
256+
) where{Tλ,TL}
257+
T = promote_type(Tλ, TL)
255258
new{T,typeof(λ),typeof(L)}(λ, L, cache)
256259
end
257260
end
@@ -297,7 +300,7 @@ for op in (
297300
end
298301
LinearAlgebra.opnorm(L::ScaledOperator, p::Real=2) = abs(L.λ) * opnorm(L.L, p)
299302

300-
getops(L::ScaledOperator) = (L.λ, L.L)
303+
getops(L::ScaledOperator) = (L.λ, L.L,)
301304
islinear(L::ScaledOperator) = all(islinear, L.ops)
302305
isconstant(L::ScaledOperator) = isconstant(L.L) & isconstant(L.λ)
303306
Base.iszero(L::ScaledOperator) = iszero(L.L) | iszero(L.λ)
@@ -306,6 +309,12 @@ has_mul!(L::ScaledOperator) = has_mul!(L.L)
306309
has_ldiv(L::ScaledOperator) = has_ldiv(L.L) & !iszero(L.λ)
307310
has_ldiv!(L::ScaledOperator) = has_ldiv!(L.L) & !iszero(L.λ)
308311

312+
function cache_internals(L::ScaledOperator, u::AbstractVector)
313+
@set! L.L = cache_operator(L.L, u)
314+
@set! L.λ = cache_operator(L.λ, u)
315+
L
316+
end
317+
309318
# getindex
310319
Base.getindex(L::ScaledOperator, i::Int) = L.coeff * L.op[i]
311320
Base.getindex(L::ScaledOperator, I::Vararg{Int, N}) where {N} = L.λ * L.L[I...]
@@ -348,7 +357,7 @@ function LinearAlgebra.ldiv!(L::ScaledOperator, u::AbstractVector)
348357
end
349358

350359
"""
351-
Lazy operator addition (A + B)
360+
Lazy operator addition
352361
353362
(A1 + A2 + A3...)u = A1*u + A2*u + A3*u ....
354363
"""
@@ -357,17 +366,22 @@ struct AddedOperator{T,
357366
} <: AbstractSciMLOperator{T}
358367
ops::O
359368

360-
function AddedOperator(ops...)
361-
sz = size(first(ops))
362-
for op in ops[2:end]
363-
@assert size(op) == sz "Size mismatich in operators $ops"
364-
end
365-
369+
function AddedOperator(ops)
366370
T = promote_type(eltype.(ops)...)
367371
new{T,typeof(ops)}(ops)
368372
end
369373
end
370374

375+
function AddedOperator(ops::AbstractSciMLOperator...)
376+
sz = size(first(ops))
377+
for op in ops[2:end]
378+
@assert size(op) == sz "Size mismatich in operators $ops"
379+
end
380+
AddedOperator(ops)
381+
end
382+
383+
AddedOperator(L::AbstractSciMLOperator) = L
384+
371385
# constructors
372386
Base.:+(ops::AbstractSciMLOperator...) = AddedOperator(ops...)
373387
Base.:-(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = AddedOperator(A, -B)
@@ -423,6 +437,13 @@ getops(L::AddedOperator) = L.ops
423437
Base.iszero(L::AddedOperator) = all(iszero, getops(L))
424438
has_adjoint(L::AddedOperator) = all(has_adjoint, L.ops)
425439

440+
function cache_internals(L::AddedOperator, u::AbstractVector)
441+
for i=1:length(L.ops)
442+
@set! L.ops[i] = cache_operator(L.ops[i], u)
443+
end
444+
L
445+
end
446+
426447
getindex(L::AddedOperator, i::Int) = sum(op -> op[i], L.ops)
427448
getindex(L::AddedOperator, I::Vararg{Int, N}) where {N} = sum(op -> op[I...], L.ops)
428449

@@ -532,30 +553,35 @@ end
532553
Base.:*(L::ComposedOperator, u::AbstractVector) = foldl((acc, op) -> op * acc, reverse(L.ops); init=u)
533554
Base.:\(L::ComposedOperator, u::AbstractVector) = foldl((acc, op) -> op \ acc, L.ops; init=u)
534555

535-
function cache_operator(L::ComposedOperator, u::AbstractVector)
536-
# for 3 arg mul!
537-
# Tuple of N-1 cache vectors. cache[N-1] = op[N] * u and so on
538-
vec = u
539-
c3 = ()
556+
function cache_self(L::ComposedOperator, u::AbstractVector)
557+
vec = similar(u)
558+
cache = (vec,)
540559
for i in reverse(2:length(L.ops))
541-
vec = L.ops[i] * vec
542-
c3 = (c3..., vec)
560+
vec = L.ops[i] * vec
561+
cache = (vec, cache...)
543562
end
544563

545-
# for 5 arg mul!
546-
c5 = similar(u)
564+
@set! L.cache = cache
565+
L
566+
end
547567

548-
cache = (;c3=c3, c5=c5)
568+
function cache_internals(L::ComposedOperator, u::AbstractVector)
569+
if !(L.isset)
570+
L = cache_self(L, u)
571+
end
572+
573+
vecs = L.cache
574+
for i in reverse(1:length(L.ops))
575+
@set! L.ops[i] = cache_operator(L.ops[i], vecs[i])
576+
end
549577

550-
@set! L.cache = cache
551578
L
552579
end
553580

554581
function LinearAlgebra.mul!(v::AbstractVector, L::ComposedOperator, u::AbstractVector)
555582
@assert L.isset "cache needs to be set up to use LinearAlgebra.mul!"
556583

557-
cache = L.cache.c3
558-
vecs = (v, cache..., u)
584+
vecs = (v, L.cache[1:end-1]..., u)
559585
for i in reverse(1:length(L.ops))
560586
mul!(vecs[i], L.ops[i], vecs[i+1])
561587
end
@@ -565,7 +591,7 @@ end
565591
function LinearAlgebra.mul!(v::AbstractVector, L::ComposedOperator, u::AbstractVector, α, β)
566592
@assert L.isset "cache needs to be set up to use LinearAlgebra.mul!"
567593

568-
cache = L.cache.c5
594+
cache = L.cache[end]
569595
copy!(cache, v)
570596

571597
mul!(v, L, u)
@@ -576,8 +602,7 @@ end
576602
function LinearAlgebra.ldiv!(v::AbstractVector, L::ComposedOperator, u::AbstractVector)
577603
@assert L.isset "cache needs to be set up to use 3 arg LinearAlgebra.ldiv!"
578604

579-
cache = L.cache.c3
580-
vecs = (u, reverse(cache)..., v)
605+
vecs = (u, reverse(L.cache[1:end-1])..., v)
581606
for i in 1:length(L.ops)
582607
ldiv!(vecs[i+1], L.ops[i], vecs[i])
583608
end
@@ -645,6 +670,11 @@ for (op, LType, VType) in (
645670
has_ldiv!,
646671
)
647672

673+
@eval function cache_internals(L::$LType, u::AbstractVector)
674+
@set! L.L = cache_operator(L.L, _reshape(u, size(L,1)))
675+
L
676+
end
677+
648678
# oeprator application
649679
@eval Base.:*(u::$VType, L::$LType) = $op(L.L * u.parent)
650680
@eval Base.:/(u::$VType, L::$LType) = $op(L.L \ u.parent)
@@ -724,12 +754,17 @@ has_ldiv!(L::InvertedOperator) = has_mul!(L.L)
724754
Base.:*(L::InvertedOperator, u::AbstractVector) = L.L \ u
725755
Base.:\(L::InvertedOperator, u::AbstractVector) = L.L * u
726756

727-
function cache_operator(L::InvertedOperator, u::AbstractVector)
757+
function cache_self(L::InvertedOperator, u::AbstractVector)
728758
cache = similar(u)
729759
@set! L.cache = cache
730760
L
731761
end
732762

763+
function cache_internals(L::InvertedOperator, u::AbstractVector)
764+
@set! L.L = cache_operator(L.L, u)
765+
L
766+
end
767+
733768
function LinearAlgebra.mul!(v::AbstractVector, L::InvertedOperator, u::AbstractVector)
734769
ldiv!(v, L.L, u)
735770
end

src/interface.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,24 @@ end
2929
(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)
3030
(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u))
3131

32+
###
33+
# caching interface
34+
###
35+
3236
"""
3337
Allocate caches for a SciMLOperator for fast evaluation
3438
3539
arguments:
3640
L :: AbstractSciMLOperator
3741
u :: AbstractVector argument to L
3842
"""
43+
cache_operator(L, u) = L
44+
cache_self(L, u) = L
45+
cache_internals(L, u) = L
46+
3947
function cache_operator(L::AbstractSciMLOperator, u::AbstractVector)
48+
L = cache_self(L, u)
49+
L = cache_internals(L, u)
4050
L
4151
end
4252

src/sciml.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ LinearAlgebra.issymmetric(L::FunctionOperator) = L.traits.issymmetric
457457
LinearAlgebra.ishermitian(L::FunctionOperator) = L.traits.ishermitian
458458
LinearAlgebra.isposdef(L::FunctionOperator) = L.traits.isposdef
459459

460+
getops(::FunctionOperator) = ()
460461
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
461462
has_mul(L::FunctionOperator{iip}) where{iip} = !iip
462463
has_mul!(L::FunctionOperator{iip}) where{iip} = iip
@@ -467,7 +468,7 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
467468
Base.:*(L::FunctionOperator, u::AbstractVector) = L.op(u, L.p, L.t)
468469
Base.:\(L::FunctionOperator, u::AbstractVector) = L.op_inverse(u, L.p, L.t)
469470

470-
function cache_operator(L::FunctionOperator, u::AbstractVector)
471+
function cache_self(L::FunctionOperator, u::AbstractVector)
471472
@set! L.cache = similar(u)
472473
L
473474
end
@@ -584,16 +585,28 @@ function Base.:\(L::TensorProductOperator, u::AbstractVector)
584585
_vec(V)
585586
end
586587

587-
function cache_operator(L::TensorProductOperator, u::AbstractVector)
588+
function cache_self(L::TensorProductOperator, u::AbstractVector)
588589
sz = (size(L.inner, 2), size(L.outer, 2))
589590
U = _reshape(u, sz)
590591
cache = L.inner * U
591592

592593
@set! L.cache = cache
594+
L
595+
end
596+
597+
function cache_internals(L::TensorProductOperator, u::AbstractVector)
598+
if !(L.isset)
599+
L = cache_self(L, u)
600+
end
601+
602+
sz = (size(L.inner, 2), size(L.outer, 2))
603+
U = _reshape(u, sz)
593604

594-
L.inner isa AbstractSciMLOperator && @set! L.inner = cache_operator(L.inner)
595-
L.outer isa AbstractSciMLOperator && @set! L.outer = cache_operator(L.outer)
605+
uinner = U
606+
uouter = transpose(L.cache)
596607

608+
@set! L.inner = cache_operator(L.inner, uinner)
609+
@set! L.outer = cache_operator(L.outer, uouter)
597610
L
598611
end
599612

test/sciml.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using SciMLOperators, LinearAlgebra
22
using Random
33

4-
using SciMLOperators: InvertibleOperator,
4+
using SciMLOperators: AbstractSciMLOperator, InvertibleOperator,
55

66
Random.seed!(0)
77
N = 8
@@ -197,6 +197,30 @@ end
197197
end
198198

199199
@testset "Operator Algebra" begin
200-
# try out array arithmatic
200+
N2 = N*N
201+
A = rand(N,N)
202+
B = rand(N,N)
203+
C = rand(N,N)
204+
D = rand(N,N)
205+
206+
u = rand(N2)
207+
α = rand()
208+
β = rand()
209+
210+
T1 = (A, B)
211+
T2 = (C, D)
212+
213+
D1 = DiagonalOperator(rand(N2))
214+
D2 = DiagonalOperator(rand(N2))
215+
216+
TT = AbstractSciMLOperator[T1, T2]
217+
DD = Diagonal(AbstractSciMLOperator[D1, D2])
218+
219+
op = TT' * DD * TT
220+
221+
op = cache_operator(op, u)
222+
223+
v=rand(N2); @test mul!(v, op, u) op * u
224+
v=rand(N2); w=copy(v); @test mul!(v, op, u, α, β) α*(op * u) + β * w
201225
end
202226
#

0 commit comments

Comments
 (0)