Skip to content

Commit e6d81d0

Browse files
committed
done
1 parent 0735593 commit e6d81d0

File tree

3 files changed

+31
-15
lines changed

3 files changed

+31
-15
lines changed

src/func.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function FunctionOperator(op,
9191
isinplace::Union{Nothing,Bool}=nothing,
9292
outofplace::Union{Nothing,Bool}=nothing,
9393
has_mul5::Union{Nothing,Bool}=nothing,
94-
cache::Union{Nothing, Bool, NTuple{2}}=nothing,
94+
cache::Union{Nothing, NTuple{2}}=nothing,
9595
T::Union{Type{<:Number},Nothing}=nothing,
9696

9797
op_adjoint=nothing,
@@ -113,8 +113,8 @@ function FunctionOperator(op,
113113
)
114114

115115
sz = (size(output, 1), size(input, 1))
116-
T = T isa Nothing ? promote_type(eltype.((input, output))...) : T
117-
t = t isa Nothing ? zero(real(T)) : t
116+
T = isnothing(T) ? promote_type(eltype.((input, output))...) : T
117+
t = isnothing(t) ? zero(real(T)) : t
118118

119119
isinplace = if isnothing(isinplace)
120120
static_hasmethod(op, typeof((output, input, p, t)))
@@ -141,14 +141,6 @@ function FunctionOperator(op,
141141
has_mul5
142142
end
143143

144-
need_cache = if isnothing(cache)
145-
true
146-
elseif cache isa Bool
147-
need_cache = cache
148-
cache = nothing
149-
need_cache
150-
end
151-
152144
if !isinplace & !outofplace
153145
@error "Please provide a funciton with signatures `op(u, p, t)` for applying
154146
the operator out-of-place, and/or the signature is `op(du, u, p, t)` for
@@ -181,7 +173,7 @@ function FunctionOperator(op,
181173
isinplace = isinplace,
182174
outofplace = outofplace,
183175
has_mul5 = has_mul5,
184-
need_cache = need_cache,
176+
ifcache = ifcache,
185177
T = T,
186178
size = sz,
187179
)
@@ -197,7 +189,11 @@ function FunctionOperator(op,
197189
cache,
198190
)
199191

200-
ifcache ? cache_operator(L, input, output) : L
192+
if ifcache & isnothing(L.cache)
193+
L = cache_operator(L, input, output)
194+
end
195+
196+
L
201197
end
202198

203199
function update_coefficients(L::FunctionOperator, u, p, t)
@@ -229,7 +225,13 @@ function update_coefficients!(L::FunctionOperator, u, p, t)
229225
nothing
230226
end
231227

228+
function iscached(L::FunctionOperator)
229+
L.traits.ifcache ? !isnothing(L.cache) : !L.traits.ifcache
230+
!isnothing(L.cache)
231+
end
232+
232233
function cache_self(L::FunctionOperator, u::AbstractVecOrMat, v::AbstractVecOrMat)
234+
L.traits.ifcache && @warn "you are allocating cache for a FunctionOperator for which ifcache = false."
233235
@set! L.cache = zero.((u, v))
234236
L
235237
end
@@ -373,7 +375,7 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::
373375
@error "LinearAlgebra.mul! not defined for out-of-place FunctionOperators"
374376
end
375377

376-
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat, α, β)
378+
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, false}, u::AbstractVecOrMat, α, β) where{oop}
377379
_, co = L.cache
378380

379381
copy!(co, v)
@@ -382,6 +384,10 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::A
382384
axpy!(β, co, v)
383385
end
384386

387+
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop}
388+
L.op(v, u, L.p, L.t, α, β)
389+
end
390+
385391
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
386392
L.op_inverse(v, u, L.p, L.t)
387393
end

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ getops(L) = ()
3939
function iscached(L::AbstractSciMLOperator)
4040

4141
has_cache = hasfield(typeof(L), :cache) # TODO - confirm this is static
42-
isset = has_cache ? L.cache !== nothing : true
42+
isset = has_cache ? !isnothing(L.cache) : true
4343

4444
return isset & all(iscached, getops(L))
4545
end

test/func.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ K = 12
1717

1818
A = rand(N,N) |> Symmetric
1919
F = lu(A)
20+
Ai = inv(A)
2021

2122
f1(u, p, t) = A * u
2223
f1i(u, p, t) = A \ u
2324

2425
f2(du, u, p, t) = mul!(du, A, u)
26+
f2(du, u, p, t, α, β) = mul!(du, A, u, α, β)
2527
f2i(du, u, p, t) = ldiv!(du, F, u)
28+
f2i(du, u, p, t, α, β) = mul!(du, Ai, u, α, β)
2629

2730
# out of place
2831
op1 = FunctionOperator(f1, u, A*u;
@@ -51,6 +54,7 @@ K = 12
5154
ishermitian=true,
5255
isposdef=true,
5356
)
57+
5458
@test issquare(op1)
5559
@test issquare(op2)
5660

@@ -76,6 +80,12 @@ K = 12
7680
@test !iscached(op1)
7781
@test !iscached(op2)
7882

83+
@test !op1.traits.has_mul5
84+
@test op2.traits.has_mul5
85+
86+
# 5-arg mul! (w/o cache)
87+
v = rand(N,K); w=copy(v); @test α*(A*u)+ β*w mul!(v, op2, u, α, β)
88+
7989
op1 = cache_operator(op1, u, A * u)
8090
op2 = cache_operator(op2, u, A * u)
8191

0 commit comments

Comments
 (0)