Skip to content

Commit 1d0c33b

Browse files
authored
Merge pull request #176 from vpuri3/oop_update
OOP update_coefficients(...)
2 parents 3560369 + 8a71724 commit 1d0c33b

File tree

11 files changed

+235
-91
lines changed

11 files changed

+235
-91
lines changed

src/basic.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ Base.conj(L::ScaledOperator) = conj(L.λ) * conj(L.L)
234234
Base.resize!(L::ScaledOperator, n::Integer) = (resize!(L.L, n); L)
235235
LinearAlgebra.opnorm(L::ScaledOperator, p::Real=2) = abs(L.λ) * opnorm(L.L, p)
236236

237+
function update_coefficients(L::ScaledOperator, u, p, t)
238+
@set! L.L = update_coefficients(L.L, u, p, t)
239+
@set! L.λ = update_coefficients(L.λ, u, p, t)
240+
241+
L
242+
end
237243
getops(L::ScaledOperator) = (L.λ, L.L,)
238244
isconstant(L::ScaledOperator) = isconstant(L.L) & isconstant(L.λ)
239245
islinear(L::ScaledOperator) = islinear(L.L)
@@ -386,6 +392,13 @@ function Base.resize!(L::AddedOperator, n::Integer)
386392
L
387393
end
388394

395+
function update_coefficients(L::AddedOperator, u, p, t)
396+
for i in 1:length(L.ops)
397+
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
398+
end
399+
L
400+
end
401+
389402
getops(L::AddedOperator) = L.ops
390403
islinear(L::AddedOperator) = all(islinear, getops(L))
391404
Base.iszero(L::AddedOperator) = all(iszero, getops(L))
@@ -532,6 +545,13 @@ end
532545

533546
LinearAlgebra.opnorm(L::ComposedOperator) = prod(opnorm, L.ops)
534547

548+
function update_coefficients(L::ComposedOperator, u, p, t)
549+
for i in 1:length(L.ops)
550+
@set! L.ops[i] = update_coefficients(L.ops[i], u, p, t)
551+
end
552+
L
553+
end
554+
535555
getops(L::ComposedOperator) = L.ops
536556
islinear(L::ComposedOperator) = all(islinear, L.ops)
537557
Base.iszero(L::ComposedOperator) = all(iszero, getops(L))
@@ -688,6 +708,12 @@ function Base.resize!(L::InvertedOperator, n::Integer)
688708
L
689709
end
690710

711+
function update_coefficients(L::InvertedOperator, u, p, t)
712+
@set! L.L = update_coefficients(L.L, u, p, t)
713+
714+
L
715+
end
716+
691717
getops(L::InvertedOperator) = (L.L,)
692718
islinear(L::InvertedOperator) = islinear(L.L)
693719

src/batch.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,32 @@ by `update_coefficients!` and is assumed to have the following signature:
88
99
update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
1010
"""
11-
struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T}
11+
struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T}
1212
diag::D
1313
update_func::F
14+
update_func!::F!
1415

15-
function BatchedDiagonalOperator(
16-
diag::AbstractArray;
17-
update_func=DEFAULT_UPDATE_FUNC
18-
)
16+
function BatchedDiagonalOperator(diag::AbstractArray, update_func, update_func!)
1917
new{
2018
eltype(diag),
2119
typeof(diag),
22-
typeof(update_func)
20+
typeof(update_func),
21+
typeof(update_func!),
2322
}(
24-
diag, update_func,
23+
diag, update_func, update_func!,
2524
)
2625
end
2726
end
2827

29-
function DiagonalOperator(u::AbstractArray; update_func=DEFAULT_UPDATE_FUNC)
30-
BatchedDiagonalOperator(u; update_func=update_func)
28+
function BatchedDiagonalOperator(diag::AbstractArray;
29+
update_func = DEFAULT_UPDATE_FUNC,
30+
update_func! = DEFAULT_UPDATE_FUNC)
31+
BatchedDiagonalOperator(diag, update_func, update_func!)
32+
end
33+
34+
function DiagonalOperator(u::AbstractArray; update_func = DEFAULT_UPDATE_FUNC,
35+
update_func! = DEFAULT_UPDATE_FUNC)
36+
BatchedDiagonalOperator(u; update_func = update_func, update_func! = update_func!)
3137
end
3238

3339
# traits
@@ -45,6 +51,21 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
4551
BatchedDiagonalOperator(diag; update_func=update_func)
4652
end
4753

54+
function update_coefficients(L::BatchedDiagonalOperator,u,p,t)
55+
@set! L.diag = L.update_func(L.diag,u,p,t)
56+
end
57+
update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func!(L.diag,u,p,t); L)
58+
59+
getops(L::BatchedDiagonalOperator) = (L.diag,)
60+
61+
function isconstant(L::BatchedDiagonalOperator)
62+
L.update_func == L.update_func! == DEFAULT_UPDATE_FUNC
63+
end
64+
islinear(::BatchedDiagonalOperator) = true
65+
has_adjoint(L::BatchedDiagonalOperator) = true
66+
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
67+
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)
68+
4869
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
4970
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
5071
if isreal(L)
@@ -57,16 +78,6 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
5778
end
5879
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))
5980

60-
isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC
61-
islinear(::BatchedDiagonalOperator) = true
62-
has_adjoint(L::BatchedDiagonalOperator) = true
63-
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
64-
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)
65-
66-
getops(L::BatchedDiagonalOperator) = (L.diag,)
67-
68-
update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func(L.diag,u,p,t); nothing)
69-
7081
# operator application
7182
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
7283
Base.:\(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .\ u

src/func.jl

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -197,32 +197,26 @@ function FunctionOperator(op,
197197
end
198198

199199
function update_coefficients(L::FunctionOperator, u, p, t)
200-
op = update_coefficients(L.op, u, p, t)
201-
op_adjoint = update_coefficients(L.op_adjoint, u, p, t)
202-
op_inverse = update_coefficients(L.op_inverse, u, p, t)
203-
op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t)
200+
@set! L.op = update_coefficients(L.op, u, p, t)
201+
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t)
202+
@set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t)
203+
@set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t)
204204

205-
FunctionOperator(op,
206-
op_adjoint,
207-
op_inverse,
208-
op_adjoint_inverse,
209-
L.traits,
210-
p,
211-
t,
212-
L.cache
213-
)
205+
@set! L.p = p
206+
@set! L.t = t
207+
208+
L
214209
end
215210

216211
function update_coefficients!(L::FunctionOperator, u, p, t)
217-
ops = getops(L)
218-
for op in ops
212+
for op in getops(L)
219213
update_coefficients!(op, u, p, t)
220214
end
221215

222216
L.p = p
223217
L.t = t
224218

225-
nothing
219+
L
226220
end
227221

228222
function iscached(L::FunctionOperator)

src/interface.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,23 @@ function (::AbstractSciMLOperator) end
1717

1818
DEFAULT_UPDATE_FUNC(A,u,p,t) = A
1919

20-
update_coefficients!(L,u,p,t) = nothing
2120
update_coefficients(L,u,p,t) = L
21+
update_coefficients!(L,u,p,t) = L
22+
23+
function update_coefficients(L::AbstractSciMLOperator, u, p, t)
24+
@error """Out-of-place update method not implemented for $L.
25+
Please file an issue at https://github.com/SciML/SciMLOperators.jl
26+
with a minimal example."""
27+
end
28+
2229
function update_coefficients!(L::AbstractSciMLOperator, u, p, t)
2330
for op in getops(L)
2431
update_coefficients!(op, u, p, t)
2532
end
26-
nothing
33+
L
2734
end
2835

29-
(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)
36+
(L::AbstractSciMLOperator)(u, p, t) = update_coefficients(L, u, p, t) * u
3037
(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u))
3138
(L::AbstractSciMLOperator)(du, u, p, t, α, β) = (update_coefficients!(L, u, p, t); mul!(du, L, u, α, β))
3239

0 commit comments

Comments
 (0)