Skip to content

Commit 8b9feb2

Browse files
authored
Merge pull request #140 from vpuri3/iscached
create and export trait `iscached(L)`
2 parents f0c96d8 + dbc8634 commit 8b9feb2

File tree

9 files changed

+206
-150
lines changed

9 files changed

+206
-150
lines changed

src/SciMLOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export ScalarOperator,
5151
export update_coefficients!,
5252
update_coefficients,
5353

54+
iscached,
5455
cache_operator,
5556

5657
issquare,

src/basic.jl

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,8 @@ struct ComposedOperator{T,O,C} <: AbstractSciMLOperator{T}
424424
ops::O
425425
""" cache for 3 and 5 argument mul! """
426426
cache::C
427-
""" is cache set """
428-
isset::Bool
429427

430-
function ComposedOperator(ops, cache, isset::Bool)
428+
function ComposedOperator(ops, cache)
431429
@assert !isempty(ops)
432430
for i in reverse(2:length(ops))
433431
opcurr = ops[i]
@@ -437,14 +435,12 @@ struct ComposedOperator{T,O,C} <: AbstractSciMLOperator{T}
437435
end
438436

439437
T = promote_type(eltype.(ops)...)
440-
isset = cache !== nothing
441-
new{T,typeof(ops),typeof(cache)}(ops, cache, isset)
438+
new{T,typeof(ops),typeof(cache)}(ops, cache)
442439
end
443440
end
444441

445442
function ComposedOperator(ops::AbstractSciMLOperator...; cache = nothing)
446-
isset = cache !== nothing
447-
ComposedOperator(ops, cache, isset)
443+
ComposedOperator(ops, cache)
448444
end
449445

450446
# constructors
@@ -505,7 +501,7 @@ for op in (
505501
)
506502
@eval Base.$op(L::ComposedOperator) = ComposedOperator(
507503
$op.(reverse(L.ops))...;
508-
cache=L.isset ? reverse(L.cache) : nothing,
504+
cache=iscached(L) ? reverse(L.cache) : nothing,
509505
)
510506
end
511507
Base.conj(L::ComposedOperator) = ComposedOperator(conj.(L.ops); cache=L.cache)
@@ -582,7 +578,7 @@ function cache_self(L::ComposedOperator, u::AbstractVecOrMat)
582578
end
583579

584580
function cache_internals(L::ComposedOperator, u::AbstractVecOrMat)
585-
if !(L.isset)
581+
if !iscached(L)
586582
L = cache_self(L, u)
587583
end
588584

@@ -595,7 +591,7 @@ function cache_internals(L::ComposedOperator, u::AbstractVecOrMat)
595591
end
596592

597593
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat)
598-
@assert L.isset "cache needs to be set up for operator of type $(typeof(L)).
594+
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
599595
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
600596

601597
vecs = (v, L.cache[1:end-1]..., u)
@@ -606,7 +602,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::Abstrac
606602
end
607603

608604
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat, α, β)
609-
@assert L.isset "cache needs to be set up for operator of type $(typeof(L)).
605+
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
610606
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
611607

612608
cache = L.cache[end]
@@ -618,7 +614,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ComposedOperator, u::Abstrac
618614
end
619615

620616
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::ComposedOperator, u::AbstractVecOrMat)
621-
@assert L.isset "cache needs to be set up for operator of type $(typeof(L)).
617+
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
622618
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
623619

624620
vecs = (u, reverse(L.cache[1:end-1])..., v)
@@ -642,17 +638,14 @@ end
642638
struct InvertedOperator{T, LType, C} <: AbstractSciMLOperator{T}
643639
L::LType
644640
cache::C
645-
isset::Bool
646641

647-
function InvertedOperator(L::AbstractSciMLOperator{T}, cache, isset) where{T}
648-
isset = cache !== nothing
649-
new{T,typeof(L),typeof(cache)}(L, cache, isset)
642+
function InvertedOperator(L::AbstractSciMLOperator{T}, cache) where{T}
643+
new{T,typeof(L),typeof(cache)}(L, cache)
650644
end
651645
end
652646

653647
function InvertedOperator(L::AbstractSciMLOperator{T}; cache=nothing) where{T}
654-
isset = cache !== nothing
655-
InvertedOperator(L, cache, isset)
648+
InvertedOperator(L, cache)
656649
end
657650

658651
Base.inv(L::AbstractSciMLOperator) = InvertedOperator(L)
@@ -663,8 +656,8 @@ Base.:/(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = A * inv(B)
663656
Base.convert(::Type{AbstractMatrix}, L::InvertedOperator) = inv(convert(AbstractMatrix, L.L))
664657

665658
Base.size(L::InvertedOperator) = size(L.L) |> reverse
666-
Base.transpose(L::InvertedOperator) = InvertedOperator(transpose(L.L); cache = L.isset ? L.cache' : nothing)
667-
Base.adjoint(L::InvertedOperator) = InvertedOperator(adjoint(L.L); cache = L.isset ? L.cache' : nothing)
659+
Base.transpose(L::InvertedOperator) = InvertedOperator(transpose(L.L); cache = iscached(L) ? L.cache' : nothing)
660+
Base.adjoint(L::InvertedOperator) = InvertedOperator(adjoint(L.L); cache = iscached(L) ? L.cache' : nothing)
668661
Base.conj(L::InvertedOperator) = InvertedOperator(conj(L.L); cache=L.cache)
669662

670663
getops(L::InvertedOperator) = (L.L,)
@@ -706,7 +699,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertedOperator, u::Abstrac
706699
end
707700

708701
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::InvertedOperator, u::AbstractVecOrMat, α, β)
709-
@assert L.isset "cache needs to be set up for operator of type $(typeof(L)).
702+
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
710703
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
711704

712705
copy!(L.cache, v)
@@ -720,7 +713,7 @@ function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertedOperator, u::Abstra
720713
end
721714

722715
function LinearAlgebra.ldiv!(L::InvertedOperator, u::AbstractVecOrMat)
723-
@assert L.isset "cache needs to be set up for operator of type $(typeof(L)).
716+
@assert iscached(L) "cache needs to be set up for operator of type $(typeof(L)).
724717
set up cache by calling cache_operator(L::AbstractSciMLOperator, u::AbstractArray)"
725718

726719
copy!(L.cache, u)

src/func.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
1717
p::P
1818
""" Time """
1919
t::Tt
20-
""" Is cache set? """
21-
isset::Bool
2220
""" Cache """
2321
cache::C
2422

@@ -30,16 +28,13 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
3028
traits,
3129
p,
3230
t,
33-
isset,
3431
cache
3532
)
3633

3734
iip = traits.isinplace
3835
oop = traits.outofplace
3936
T = traits.T
4037

41-
isset = cache !== nothing
42-
4338
new{
4439
iip,
4540
oop,
@@ -60,7 +55,6 @@ mutable struct FunctionOperator{iip,oop,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: Abst
6055
traits,
6156
p,
6257
t,
63-
isset,
6458
cache,
6559
)
6660
end
@@ -164,7 +158,6 @@ function FunctionOperator(op,
164158
)
165159

166160
cache = zero.((input, output))
167-
isset = true
168161

169162
FunctionOperator(
170163
op,
@@ -174,7 +167,6 @@ function FunctionOperator(op,
174167
traits,
175168
p,
176169
t,
177-
isset,
178170
cache,
179171
)
180172
end
@@ -209,8 +201,7 @@ function Base.adjoint(L::FunctionOperator)
209201
p = L.p
210202
t = L.t
211203

212-
isset = L.isset
213-
cache = if isset
204+
cache = if iscached(L)
214205
cache = reverse(L.cache)
215206
else
216207
nothing
@@ -223,7 +214,6 @@ function Base.adjoint(L::FunctionOperator)
223214
traits,
224215
p,
225216
t,
226-
isset,
227217
cache,
228218
)
229219
end
@@ -253,8 +243,7 @@ function Base.inv(L::FunctionOperator)
253243
p = L.p
254244
t = L.t
255245

256-
isset = L.cache !== nothing
257-
cache = if isset
246+
cache = if iscached(L)
258247
cache = reverse(L.cache)
259248
else
260249
nothing
@@ -267,7 +256,6 @@ function Base.inv(L::FunctionOperator)
267256
traits,
268257
p,
269258
t,
270-
isset,
271259
cache,
272260
)
273261
end

src/interface.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@ end
3333
# caching interface
3434
###
3535

36+
function iscached(L::AbstractSciMLOperator)
37+
has_cache = hasfield(typeof(L), :cache) # TODO - confirm this is static
38+
isset = has_cache ? L.cache !== nothing : true
39+
40+
return isset & all(iscached, getops(L))
41+
end
42+
43+
iscached(::Union{
44+
# LinearAlgebra
45+
AbstractMatrix,
46+
UniformScaling,
47+
Factorization,
48+
49+
# Base
50+
Number,
51+
52+
}
53+
) = true
54+
3655
"""
3756
Allocate caches for a SciMLOperator for fast evaluation
3857

0 commit comments

Comments
 (0)