Skip to content

Commit 6bdba5e

Browse files
Merge pull request #40 from vpuri3/smallfixes
5 arg mul!, 2arg ldiv! for functionoperator, traits and comments
2 parents 934f874 + 941b515 commit 6bdba5e

File tree

3 files changed

+127
-25
lines changed

3 files changed

+127
-25
lines changed

src/interface.jl

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
#
22
###
3-
# (u,p,t) and (du,u,p,t) interface
3+
# Operator interface
44
###
55

6-
#=
7-
1. Function call and multiplication: L(du, u, p, t) for inplace and du = L(u, p, t) for
8-
out-of-place, meaning L*u and mul!(du, L, u).
9-
2. If the operator is not a constant, update it with (u,p,t). A mutating form, i.e.
10-
update_coefficients!(A,u,p,t) that changes the internal coefficients, and a
11-
out-of-place form B = update_coefficients(A,u,p,t).
12-
3. isconstant(A) trait for whether the operator is constant or not.
13-
4. islinear(A) trait for whether the operator is linear or not.
14-
=#
6+
"""
7+
Function call and multiplication:
8+
- L(du, u, p, t) for in-place operator evaluation,
9+
- du = L(u, p, t) for out-of-place operator evaluation
10+
11+
If the operator is not a constant, update it with (u,p,t). A mutating form, i.e.
12+
update_coefficients!(A,u,p,t) that changes the internal coefficients, and a
13+
out-of-place form B = update_coefficients(A,u,p,t).
14+
15+
"""
16+
function (::AbstractSciMLOperator) end
1517

1618
DEFAULT_UPDATE_FUNC(A,u,p,t) = A
1719

@@ -39,7 +41,7 @@ function cache_operator(L::AbstractSciMLOperator, u::AbstractVector)
3941
end
4042

4143
###
42-
# AbstractSciMLOperator Traits
44+
# Operator Traits
4345
###
4446

4547
Base.size(A::AbstractSciMLOperator, d::Integer) = d <= 2 ? size(A)[d] : 1
@@ -94,13 +96,24 @@ islinear(::Union{
9496
# Base
9597
Number,
9698

97-
# SciMLOperator
99+
# SciMLOperators
98100
AbstractSciMLLinearOperator,
99101
}
100102
) = true
101103

102104
has_mul(L) = true
103-
has_mul!(L) = true
105+
106+
has_mul!(L) = false
107+
has_mul!(::Union{
108+
# LinearAlgebra
109+
AbstractVector,
110+
AbstractMatrix,
111+
UniformScaling,
112+
113+
# Base
114+
Number,
115+
}
116+
) = true
104117

105118
has_ldiv(L) = false
106119
has_ldiv(::Union{
@@ -128,14 +141,17 @@ has_adjoint(::Union{
128141
# Base
129142
Number,
130143

131-
# SciMLOperator
144+
# SciMLOperators
132145
AbstractSciMLLinearOperator,
133146
}
134147
) = true
135148

136149
issquare(A) = size(A,1) === size(A,2)
137150
issquare(::Union{
151+
# LinearAlgebra
138152
UniformScaling,
153+
154+
# Base
139155
Number,
140156
}
141157
) = true
@@ -145,8 +161,10 @@ issquare(A...) = @. (&)(issquare(A)...)
145161
# default linear operator traits
146162
###
147163

148-
Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator) =
164+
function Base.:(==)(L1::AbstractSciMLOperator, L2::AbstractSciMLOperator)
165+
size(L1) != size(L2) && return false
149166
convert(AbstractMatrix, L1) == convert(AbstractMatrix, L1)
167+
end
150168

151169
LinearAlgebra.exp(L::AbstractSciMLLinearOperator,t) = exp(t*L)
152170
has_exp(L::AbstractSciMLLinearOperator) = true

src/sciml.jl

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@ Base.similar(L::MatrixOperator, ::Type{T}, dims::Dims) where{T} = MatrixOperator
1919

2020
# traits
2121
@forward MatrixOperator.A (
22-
issquare, has_ldiv, has_ldiv!
22+
LinearAlgebra.isreal,
23+
LinearAlgebra.issymmetric,
24+
LinearAlgebra.ishermitian,
25+
LinearAlgebra.isposdef,
26+
27+
issquare,
28+
has_ldiv,
29+
has_ldiv!,
2330
)
2431
Base.size(L::MatrixOperator) = size(L.A)
2532
Base.adjoint(L::MatrixOperator) = MatrixOperator(L.A'; update_func=(A,u,p,t)->L.update_func(L.A,u,p,t)')
@@ -118,6 +125,13 @@ for op in (
118125
end
119126
end
120127

128+
for op in (
129+
:*, :
130+
)
131+
@eval Base.$op(A::AbstractMatrix, B::AbstractSciMLOperator) = $op(MatrixOperator(A), B)
132+
@eval Base.$op(A::AbstractSciMLOperator, B::AbstractMatrix) = $op(A, MatrixOperator(B))
133+
end
134+
121135
""" Diagonal Operator """
122136
DiagonalOperator(u::AbstractVector) = MatrixOperator(Diagonal(u))
123137
LinearAlgebra.Diagonal(L::MatrixOperator) = MatrixOperator(Diagonal(L.A))
@@ -252,7 +266,7 @@ end
252266
"""
253267
Matrix free operators (given by a function)
254268
"""
255-
struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt} <: AbstractSciMLOperator{T}
269+
struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
256270
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
257271
op::F
258272
""" Adjoint operator"""
@@ -267,11 +281,27 @@ struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt} <: AbstractSciMLOperato
267281
p::P
268282
""" Time """
269283
t::Tt
284+
""" Is cache set? """
285+
isset::Bool
286+
""" Cache """
287+
cache::C
288+
289+
function FunctionOperator(op,
290+
op_adjoint,
291+
op_inverse,
292+
op_adjoint_inverse,
293+
traits,
294+
p,
295+
t,
296+
isset,
297+
cache
298+
)
270299

271-
function FunctionOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t)
272300
iip = traits.isinplace
273301
T = traits.T
274302

303+
isset = cache !== nothing
304+
275305
new{iip,
276306
T,
277307
typeof(op),
@@ -281,11 +311,19 @@ struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt} <: AbstractSciMLOperato
281311
typeof(traits),
282312
typeof(p),
283313
typeof(t),
314+
typeof(cache),
284315
}(
285-
op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t,
316+
op,
317+
op_adjoint,
318+
op_inverse,
319+
op_adjoint_inverse,
320+
traits,
321+
p,
322+
t,
323+
isset,
324+
cache,
286325
)
287326
end
288-
289327
end
290328

291329
function FunctionOperator(op;
@@ -303,6 +341,8 @@ function FunctionOperator(op;
303341
p=nothing,
304342
t=nothing,
305343

344+
cache=nothing,
345+
306346
# traits
307347
opnorm=nothing,
308348
isreal=true,
@@ -331,6 +371,8 @@ function FunctionOperator(op;
331371
op_adjoint_inverse = op_inverse
332372
end
333373

374+
t = t isa Nothing ? zero(T) : t
375+
334376
traits = (;
335377
opnorm = opnorm,
336378
isreal = isreal,
@@ -343,6 +385,8 @@ function FunctionOperator(op;
343385
size = size,
344386
)
345387

388+
isset = cache !== nothing
389+
346390
FunctionOperator(
347391
op,
348392
op_adjoint,
@@ -351,6 +395,8 @@ function FunctionOperator(op;
351395
traits,
352396
p,
353397
t,
398+
isset,
399+
cache,
354400
)
355401
end
356402

@@ -361,7 +407,7 @@ function update_coefficients!(L::FunctionOperator, u, p, t)
361407
end
362408

363409
Base.size(L::FunctionOperator) = L.traits.size
364-
function Base.adjoint(L::FunctionOperator{iip,T}) where{iip,T}
410+
function Base.adjoint(L::FunctionOperator)
365411

366412
if ishermitian(L) | (isreal(L) & issymmetric(L))
367413
return L
@@ -382,7 +428,20 @@ function Base.adjoint(L::FunctionOperator{iip,T}) where{iip,T}
382428
p = L.p
383429
t = L.t
384430

385-
FuncitonOperator(op, op_adjoint, op_inverse, op_adjoint_inverse, traits, p, t)
431+
cache = issquare(L) ? cache : nothing
432+
isset = cache !== nothing
433+
434+
435+
FuncitonOperator(op,
436+
op_adjoint,
437+
op_inverse,
438+
op_adjoint_inverse,
439+
traits,
440+
p,
441+
t,
442+
isset,
443+
cache
444+
)
386445
end
387446

388447
function LinearAlgebra.opnorm(L::FunctionOperator, p)
@@ -409,11 +468,30 @@ has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothin
409468
Base.:*(L::FunctionOperator, u::AbstractVector) = L.op(u, L.p, L.t)
410469
Base.:\(L::FunctionOperator, u::AbstractVector) = L.op_inverse(u, L.p, L.t)
411470

471+
function cache_operator(L::FunctionOperator, u::AbstractVector)
472+
@set! L.cache = similar(u)
473+
L
474+
end
475+
412476
function LinearAlgebra.mul!(v::AbstractVector, L::FunctionOperator, u::AbstractVector)
413477
L.op(v, u, L.p, L.t)
414478
end
415479

480+
function LinearAlgebra.mul!(v::AbstractVector, L::FunctionOperator, u::AbstractVector, α, β)
481+
@assert L.isset "set up cache by calling cache_operator($L, $u)"
482+
copy!(L.cache, v)
483+
mul!(v, L, u)
484+
lmul!(α, v)
485+
axpy!(β, L.cache, v)
486+
end
487+
416488
function LinearAlgebra.ldiv!(v::AbstractVector, L::FunctionOperator, u::AbstractVector)
417489
L.op_inverse(v, u, L.p, L.t)
418490
end
491+
492+
function LinearAlgebra.ldiv!(L::FunctionOperator, u::AbstractVector)
493+
@assert L.isset "set up cache by calling cache_operator($L, $u)"
494+
copy!(L.cache, u)
495+
ldiv!(u, L, L.cache)
496+
end
419497
#

test/sciml.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ end
7272
u = rand(N)
7373
p = nothing
7474
t = 0.0
75+
α = rand()
76+
β = rand()
7577

7678
A = rand(N,N) |> Symmetric
7779
F = lu(A)
@@ -131,10 +133,14 @@ end
131133
@test !has_ldiv(op2)
132134
@test has_ldiv!(op2)
133135

134-
v = zero(u); @test A * u op1 * u mul!(v, op2, u)
135-
v = zero(u); @test A * u op1(u,p,t) op2(v,u,p,t)
136+
op2 = cache_operator(op2, u)
137+
138+
v = rand(N); @test A * u op1 * u mul!(v, op2, u)
139+
v = rand(N); @test A * u op1(u,p,t) op2(v,u,p,t)
140+
v = rand(N); w=copy(v); @test α*(A*u)+ β*w mul!(v, op2, u, α, β)
136141

137-
v = zero(u); @test A \ u op1 \ u ldiv!(v, op2, u)
142+
v = rand(N); @test A \ u op1 \ u ldiv!(v, op2, u)
143+
v = copy(u); @test A \ v ldiv!(op2, u)
138144
end
139145

140146
@testset "Operator Algebra" begin

0 commit comments

Comments
 (0)