Skip to content

Commit 941b515

Browse files
committed
tests done
1 parent 0723814 commit 941b515

File tree

2 files changed

+50
-19
lines changed

2 files changed

+50
-19
lines changed

src/sciml.jl

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,21 @@ struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOpera
286286
""" Cache """
287287
cache::C
288288

289-
function FunctionOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t, isset, cache)
289+
function FunctionOperator(op,
290+
op_adjoint,
291+
op_inverse,
292+
op_adjoint_inverse,
293+
traits,
294+
p,
295+
t,
296+
isset,
297+
cache
298+
)
299+
290300
iip = traits.isinplace
291301
T = traits.T
292302

293-
isset = cache isa Nothing
303+
isset = cache !== nothing
294304

295305
new{iip,
296306
T,
@@ -303,7 +313,15 @@ struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOpera
303313
typeof(t),
304314
typeof(cache),
305315
}(
306-
op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t, isset, cache,
316+
op,
317+
op_adjoint,
318+
op_inverse,
319+
op_adjoint_inverse,
320+
traits,
321+
p,
322+
t,
323+
isset,
324+
cache,
307325
)
308326
end
309327
end
@@ -367,7 +385,7 @@ function FunctionOperator(op;
367385
size = size,
368386
)
369387

370-
isset = cache isa Nothing
388+
isset = cache !== nothing
371389

372390
FunctionOperator(
373391
op,
@@ -411,10 +429,19 @@ function Base.adjoint(L::FunctionOperator)
411429
t = L.t
412430

413431
cache = issquare(L) ? cache : nothing
414-
isset = cache isa Nothing
432+
isset = cache !== nothing
415433

416434

417-
FuncitonOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t, isset, cache)
435+
FuncitonOperator(op,
436+
op_adjoint,
437+
op_inverse,
438+
op_adjoint_inverse,
439+
traits,
440+
p,
441+
t,
442+
isset,
443+
cache
444+
)
418445
end
419446

420447
function LinearAlgebra.opnorm(L::FunctionOperator, p)
@@ -441,7 +468,7 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
441468
Base.:*(L::FunctionOperator, u::AbstractVector) = L.op(u, L.p, L.t)
442469
Base.:\(L::FunctionOperator, u::AbstractVector) = L.op_inverse(u, L.p, L.t)
443470

444-
function update_cache(L::FunctionOperator, u::AbstractVector)
471+
function cache_operator(L::FunctionOperator, u::AbstractVector)
445472
@set! L.cache = similar(u)
446473
L
447474
end
@@ -451,22 +478,20 @@ function LinearAlgebra.mul!(v::AbstractVector, L::FunctionOperator, u::AbstractV
451478
end
452479

453480
function LinearAlgebra.mul!(v::AbstractVector, L::FunctionOperator, u::AbstractVector, α, β)
454-
try
455-
L.op(v, u, L.p, L.t, α, β)
456-
catch
457-
copy!(L.cache, v)
458-
mul!(v, u, L.p, L.t)
459-
lmul!(α, v)
460-
axpy!(β, L.cache, v)
461-
end
481+
@assert L.isset "set up cache by calling cache_operator($L, $u)"
482+
copy!(L.cache, v)
483+
mul!(v, L, u)
484+
lmul!(α, v)
485+
axpy!(β, L.cache, v)
462486
end
463487

464488
function LinearAlgebra.ldiv!(v::AbstractVector, L::FunctionOperator, u::AbstractVector)
465489
L.op_inverse(v, u, L.p, L.t)
466490
end
467491

468492
function LinearAlgebra.ldiv!(L::FunctionOperator, u::AbstractVector)
493+
@assert L.isset "set up cache by calling cache_operator($L, $u)"
469494
copy!(L.cache, u)
470-
L.op_inverse(u, L.cache, L.p, L.t)
495+
ldiv!(u, L, L.cache)
471496
end
472497
#

test/sciml.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ end
7272
u = rand(N)
7373
p = nothing
7474
t = 0.0
75+
α = rand()
76+
β = rand()
7577

7678
A = rand(N,N) |> Symmetric
7779
F = lu(A)
@@ -131,10 +133,14 @@ end
131133
@test !has_ldiv(op2)
132134
@test has_ldiv!(op2)
133135

134-
v = zero(u); @test A * u op1 * u mul!(v, op2, u)
135-
v = zero(u); @test A * u op1(u,p,t) op2(v,u,p,t)
136+
op2 = cache_operator(op2, u)
137+
138+
v = rand(N); @test A * u op1 * u mul!(v, op2, u)
139+
v = rand(N); @test A * u op1(u,p,t) op2(v,u,p,t)
140+
v = rand(N); w=copy(v); @test α*(A*u)+ β*w mul!(v, op2, u, α, β)
136141

137-
v = zero(u); @test A \ u op1 \ u ldiv!(v, op2, u)
142+
v = rand(N); @test A \ u op1 \ u ldiv!(v, op2, u)
143+
v = copy(u); @test A \ v ldiv!(op2, u)
138144
end
139145

140146
@testset "Operator Algebra" begin

0 commit comments

Comments
 (0)