Skip to content

Commit 16cb419

Browse files
Merge pull request #150 from vpuri3/func_cache
Update FunctionOperator
2 parents 854161d + 66681bb commit 16cb419

File tree

3 files changed

+74
-18
lines changed

3 files changed

+74
-18
lines changed

src/func.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ function FunctionOperator(op,
9797
p=nothing,
9898
t::Union{Number,Nothing}=nothing,
9999

100+
ifcache::Bool = true,
101+
100102
# traits
101103
islinear::Bool = false,
102104

@@ -157,18 +159,20 @@ function FunctionOperator(op,
157159
size = sz,
158160
)
159161

160-
cache = zero.((input, output))
162+
cache = nothing
161163

162-
FunctionOperator(
163-
op,
164-
op_adjoint,
165-
op_inverse,
166-
op_adjoint_inverse,
167-
traits,
168-
p,
169-
t,
170-
cache,
171-
)
164+
L = FunctionOperator(
165+
op,
166+
op_adjoint,
167+
op_inverse,
168+
op_adjoint_inverse,
169+
traits,
170+
p,
171+
t,
172+
cache,
173+
)
174+
175+
ifcache ? cache_operator(L, input, output) : L
172176
end
173177

174178
function update_coefficients(L::FunctionOperator, u, p, t)
@@ -200,6 +204,11 @@ function update_coefficients!(L::FunctionOperator, u, p, t)
200204
nothing
201205
end
202206

207+
function cache_self(L::FunctionOperator, u::AbstractVecOrMat, v::AbstractVecOrMat)
208+
@set! L.cache = zero.((u, v))
209+
L
210+
end
211+
203212
Base.size(L::FunctionOperator) = L.traits.size
204213
function Base.adjoint(L::FunctionOperator)
205214

@@ -306,7 +315,6 @@ function getops(L::FunctionOperator)
306315
end
307316

308317
#TODO - isconstant(L::FunctionOperator)
309-
iscached(::FunctionOperator) = true
310318
islinear(L::FunctionOperator) = L.traits.islinear
311319
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
312320
has_mul(L::FunctionOperator{iip}) where{iip} = true

src/interface.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ function iscached(L::AbstractSciMLOperator)
4040
return isset & all(iscached, getops(L))
4141
end
4242

43+
iscached(L) = true
4344
iscached(::Union{
4445
# LinearAlgebra
4546
AbstractMatrix,
@@ -57,11 +58,23 @@ Allocate caches for a SciMLOperator for fast evaluation
5758
5859
arguments:
5960
L :: AbstractSciMLOperator
60-
u :: AbstractVecOrMat argument to L
61+
in :: AbstractVecOrMat input prototype to L
62+
out :: (optional) AbstractVecOrMat output prototype to L
6163
"""
64+
cache_operator
65+
6266
cache_operator(L, u) = L
63-
cache_self(L::AbstractSciMLOperator, u::AbstractVecOrMat) = L
64-
cache_internals(L::AbstractSciMLOperator, u::AbstractVecOrMat) = L
67+
cache_operatro(L, u, v) = L
68+
cache_self(L::AbstractSciMLOperator, uv::AbstractVecOrMat...) = L
69+
cache_internals(L::AbstractSciMLOperator, uv::AbstractVecOrMat...) = L
70+
71+
function cache_operator(L::AbstractSciMLOperator,
72+
u::AbstractVecOrMat,
73+
v::AbstractVecOrMat)
74+
L = cache_self(L, u, v)
75+
L = cache_internals(L, u, v)
76+
L
77+
end
6578

6679
function cache_operator(L::AbstractSciMLOperator, u::AbstractVecOrMat)
6780
L = cache_self(L, u)

test/func.jl

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ K = 12
2929

3030
op_inverse=f1i,
3131

32+
ifcache = false,
33+
3234
islinear=true,
3335
opnorm=true,
3436
issymmetric=true,
@@ -41,6 +43,8 @@ K = 12
4143

4244
op_inverse=f2i,
4345

46+
ifcache = false,
47+
4448
islinear=true,
4549
opnorm=true,
4650
issymmetric=true,
@@ -69,6 +73,12 @@ K = 12
6973
@test has_ldiv(op2)
7074
@test has_ldiv!(op2)
7175

76+
@test !iscached(op1)
77+
@test !iscached(op2)
78+
79+
op1 = cache_operator(op1, u, A * u)
80+
op2 = cache_operator(op2, u, A * u)
81+
7282
@test iscached(op1)
7383
@test iscached(op2)
7484

@@ -87,10 +97,35 @@ end
8797

8898
f(du,u,p,t) = mul!(du, Diagonal(p*t), u)
8999

90-
op = FunctionOperator(f, u, u; p=zero(p), t=zero(t))
100+
L = FunctionOperator(f, u, u; p=zero(p), t=zero(t))
91101

92102
ans = @. u * p * t
93-
@test op(u,p,t) ans
94-
v=copy(u); @test op(v,u,p,t) ans
103+
@test L(u,p,t) ans
104+
v=copy(u); @test L(v,u,p,t) ans
105+
106+
# test that output isn't accidentally mutated by passing an internal cache.
107+
108+
A = Diagonal(p * t)
109+
u1 = rand(N, K)
110+
u2 = rand(N, K)
111+
112+
v1 = L * u1; @test v1 A * u1
113+
v2 = L * u2; @test v2 A * u2; @test v1 A * u1
114+
@test v1 + v2 A * (u1 + u2)
115+
116+
v1 .= 0.0
117+
v2 .= 0.0
118+
119+
mul!(v1, L, u1); @test v1 A * u1
120+
mul!(v2, L, u2); @test v2 A * u2; @test v1 A * u1
121+
@test v1 + v2 A * (u1 + u2)
122+
123+
v1 = rand(N, K); w1 = copy(v1)
124+
v2 = rand(N, K); w2 = copy(v2)
125+
a1, a2, b1, b2 = rand(4)
126+
127+
mul!(v1, L, u1, a1, b1); @test v1 a1*A*u1 + b1*w1
128+
mul!(v2, L, u2, a2, b2); @test v2 a2*A*u2 + b2*w2; @test v1 a1*A*u1 + b1*w1
129+
@test v1 + v2 (a1*A*u1 + b1*w1) + (a2*A*u2 + b2*w2)
95130
end
96131
#

0 commit comments

Comments
 (0)