Skip to content

Commit 0b4cfed

Browse files
committed
Merge branch 'master' of github.com:SciML/SciMLOperators.jl into batch
2 parents 071bf63 + c5dfb94 commit 0b4cfed

File tree

13 files changed

+724
-93
lines changed

13 files changed

+724
-93
lines changed

src/SciMLOperators.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,17 @@ $(TYPEDEF)
2828
"""
2929
abstract type AbstractSciMLLinearOperator{T} <: AbstractSciMLOperator{T} end
3030

31+
"""
32+
$(TYPEDEF)
33+
"""
34+
abstract type AbstractSciMLScalarOperator{T} <: AbstractSciMLLinearOperator{T} end
35+
3136
include("utils.jl")
3237
include("interface.jl")
3338
include("left.jl")
3439
include("multidim.jl")
3540

41+
include("scalar.jl")
3642
include("basic.jl")
3743
include("matrix.jl")
3844
include("batch.jl")

src/basic.jl

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Base.convert(::Type{AbstractMatrix}, ::IdentityOperator{N}) where{N} = Diagonal(
1818
Base.size(::IdentityOperator{N}) where{N} = (N, N)
1919
Base.adjoint(A::IdentityOperator) = A
2020
Base.transpose(A::IdentityOperator) = A
21+
Base.conj(A::IdentityOperator) = A
2122
LinearAlgebra.opnorm(::IdentityOperator{N}, p::Real=2) where{N} = true
2223
for pred in (
2324
:issymmetric, :ishermitian, :isposdef,
@@ -108,6 +109,7 @@ Base.convert(::Type{AbstractMatrix}, ::NullOperator{N}) where{N} = Diagonal(zero
108109
Base.size(::NullOperator{N}) where{N} = (N, N)
109110
Base.adjoint(A::NullOperator) = A
110111
Base.transpose(A::NullOperator) = A
112+
Base.conj(A::NullOperator) = A
111113
LinearAlgebra.opnorm(::NullOperator{N}, p::Real=2) where{N} = false
112114
for pred in (
113115
:issymmetric, :ishermitian,
@@ -253,35 +255,26 @@ LinearAlgebra.ldiv!(α::ScalarOperator, u::AbstractVecOrMat) = ldiv!(α.val, u)
253255
(λ L)*(u) = λ * L(u)
254256
"""
255257
struct ScaledOperator{T,
256-
λType<:ScalarOperator,
257-
LType<:AbstractSciMLOperator,
258-
C,
258+
λType,
259+
LType,
259260
} <: AbstractSciMLOperator{T}
260261
λ::λType
261262
L::LType
262-
cache::C
263263

264-
function ScaledOperator::ScalarOperator{Tλ},
264+
function ScaledOperator::AbstractSciMLScalarOperator{Tλ},
265265
L::AbstractSciMLOperator{TL},
266-
cache = zeros(promote_type(Tλ,TL), 1),
267266
) where{Tλ,TL}
268267
T = promote_type(Tλ, TL)
269-
new{T,typeof(λ),typeof(L),typeof(cache)}(λ, L, cache)
268+
new{T,typeof(λ),typeof(L)}(λ, L)
270269
end
271270
end
272271

273-
ScalingNumberTypes = (
274-
:ScalarOperator,
275-
:Number,
276-
:UniformScaling,
277-
)
278-
279272
# constructors
280-
for T in ScalingNumberTypes[2:end]
273+
for T in SCALINGNUMBERTYPES[2:end]
281274
@eval ScaledOperator::$T, L::AbstractSciMLOperator) = ScaledOperator(ScalarOperator(λ), L)
282275
end
283276

284-
for T in ScalingNumberTypes
277+
for T in SCALINGNUMBERTYPES
285278
@eval function ScaledOperator::$T, L::ScaledOperator)
286279
λ = ScalarOperator(λ) * L.λ
287280
ScaledOperator(λ, L.L)
@@ -311,6 +304,7 @@ for op in (
311304
)
312305
@eval Base.$op(L::ScaledOperator) = ScaledOperator($op(L.λ), $op(L.L))
313306
end
307+
Base.conj(L::ScaledOperator) = conj(L.λ) * conj(L.L)
314308
LinearAlgebra.opnorm(L::ScaledOperator, p::Real=2) = abs(L.λ) * opnorm(L.L, p)
315309

316310
getops(L::ScaledOperator) = (L.λ, L.L,)
@@ -344,19 +338,19 @@ for fact in (
344338
end
345339

346340
# operator application, inversion
347-
for op in (
348-
:*, :\,
349-
)
350-
@eval Base.$op(L::ScaledOperator, x::AbstractVecOrMat) = $op(L.λ, $op(L.L, x))
351-
end
341+
Base.:*(L::ScaledOperator, u::AbstractVecOrMat) = L.λ * (L.L * u)
342+
Base.:\(L::ScaledOperator, u::AbstractVecOrMat) = L.λ \ (L.L \ u)
352343

353344
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat)
354-
mul!(v, L.L, u, L.λ.val, false)
345+
iszero(L.λ) && return lmul!(false, v)
346+
a = convert(Number, L.λ)
347+
mul!(v, L.L, u, a, false)
355348
end
356349

357350
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat, α, β)
358-
mul!(L.cache, [L.λ.val,], [α,])
359-
mul!(v, L.L, u, first(L.cache), β)
351+
iszero(L.λ) && return lmul!(β, v)
352+
a = convert(Number, L.λ*α)
353+
mul!(v, L.L, u, a, β)
360354
end
361355

362356
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::ScaledOperator, u::AbstractVecOrMat)
@@ -408,7 +402,7 @@ for op in (
408402
@eval Base.$op(A::AbstractSciMLOperator, B::AddedOperator) = AddedOperator(A, $op(B).ops...)
409403
@eval Base.$op(A::AddedOperator, B::AbstractSciMLOperator) = AddedOperator(A.ops..., $op(B))
410404

411-
for T in ScalingNumberTypes
405+
for T in SCALINGNUMBERTYPES
412406
@eval function Base.$op(L::AbstractSciMLOperator, λ::$T)
413407
@assert issquare(L)
414408
N = size(L, 1)
@@ -436,6 +430,7 @@ for op in (
436430
)
437431
@eval Base.$op(L::AddedOperator) = AddedOperator($op.(L.ops)...)
438432
end
433+
Base.conj(L::AddedOperator) = AddedOperator(conj.(L.ops))
439434

440435
getops(L::AddedOperator) = L.ops
441436
Base.iszero(L::AddedOperator) = all(iszero, getops(L))
@@ -546,11 +541,11 @@ for op in (
546541
end
547542

548543
# scalar operator
549-
@eval function Base.$op::ScalarOperator, L::ComposedOperator)
544+
@eval function Base.$op::AbstractSciMLScalarOperator, L::ComposedOperator)
550545
ScaledOperator(λ, L)
551546
end
552547

553-
@eval function Base.$op(L::ComposedOperator, λ::ScalarOperator)
548+
@eval function Base.$op(L::ComposedOperator, λ::AbstractSciMLScalarOperator)
554549
ScaledOperator(λ, L)
555550
end
556551
end
@@ -564,8 +559,12 @@ for op in (
564559
:adjoint,
565560
:transpose,
566561
)
567-
@eval Base.$op(L::ComposedOperator) = ComposedOperator($op.(reverse(L.ops))...)
562+
@eval Base.$op(L::ComposedOperator) = ComposedOperator(
563+
$op.(reverse(L.ops))...;
564+
cache=L.isset ? reverse(L.cache) : nothing,
565+
)
568566
end
567+
Base.conj(L::ComposedOperator) = ComposedOperator(conj.(L.ops); cache=L.cache)
569568
LinearAlgebra.opnorm(L::ComposedOperator) = prod(opnorm, L.ops)
570569

571570
getops(L::ComposedOperator) = L.ops
@@ -680,20 +679,23 @@ function InvertedOperator(L::AbstractSciMLOperator{T}; cache=nothing) where{T}
680679
end
681680

682681
Base.inv(L::AbstractSciMLOperator) = InvertedOperator(L)
682+
683+
Base.:\(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = inv(A) * B
684+
Base.:/(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = A * inv(B)
685+
683686
Base.convert(::Type{AbstractMatrix}, L::InvertedOperator) = inv(convert(AbstractMatrix, L.L))
684687

685688
Base.size(L::InvertedOperator) = size(L.L) |> reverse
686-
Base.adjoint(L::InvertedOperator) = InvertedOperator(L.L')
689+
Base.transpose(L::InvertedOperator) = InvertedOperator(transpose(L.L); cache = L.isset ? L.cache' : nothing)
690+
Base.adjoint(L::InvertedOperator) = InvertedOperator(adjoint(L.L); cache = L.isset ? L.cache' : nothing)
691+
Base.conj(L::InvertedOperator) = InvertedOperator(conj(L.L); cache=L.cache)
687692

688693
getops(L::InvertedOperator) = (L.L,)
689694

690695
has_mul!(L::InvertedOperator) = has_ldiv!(L.L)
691696
has_ldiv(L::InvertedOperator) = has_mul(L.L)
692697
has_ldiv!(L::InvertedOperator) = has_mul!(L.L)
693698

694-
Base.:\(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = inv(A) * B
695-
Base.:/(A::AbstractSciMLOperator, B::AbstractSciMLOperator) = A * inv(B)
696-
697699
@forward InvertedOperator.L (
698700
# LinearAlgebra
699701
LinearAlgebra.issymmetric,

src/func.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
Matrix free operators (given by a function)
44
"""
5-
struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
5+
mutable struct FunctionOperator{isinplace,T,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
66
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
77
op::F
88
""" Adjoint operator"""
@@ -145,9 +145,10 @@ function FunctionOperator(op;
145145
end
146146

147147
function update_coefficients!(L::FunctionOperator, u, p, t)
148-
@set! L.p = p
149-
@set! L.t = t
150-
L
148+
L.p = p
149+
L.t = t
150+
151+
nothing
151152
end
152153

153154
Base.size(L::FunctionOperator) = L.traits.size
@@ -256,6 +257,8 @@ has_mul!(L::FunctionOperator{iip}) where{iip} = iip
256257
has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
257258
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)
258259

260+
# TODO - FunctionOperator, Base.conj, transpose
261+
259262
# operator application
260263
Base.:*(L::FunctionOperator{false}, u::AbstractVecOrMat) = L.op(u, L.p, L.t)
261264
Base.:\(L::FunctionOperator{false}, u::AbstractVecOrMat) = L.op_inverse(u, L.p, L.t)

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function update_coefficients!(L::AbstractSciMLOperator, u, p, t)
2323
for op in getops(L)
2424
update_coefficients!(op, u, p, t)
2525
end
26-
L
26+
nothing
2727
end
2828

2929
(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)

src/matrix.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,17 @@ for op in (
3535
@eval function Base.$op(L::MatrixOperator) # TODO - test this thoroughly
3636
MatrixOperator(
3737
$op(L.A);
38-
update_func = (A,u,p,t) -> $op(L.update_func(L.A,u,p,t))
38+
update_func= (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t)) # TODO - test
3939
)
4040
end
4141
end
42+
Base.conj(L::MatrixOperator) = MatrixOperator(
43+
conj(L.A);
44+
update_func= (A,u,b,t) -> conj(L.update_func(conj(L.A),u,p,t))
45+
)
4246

4347
has_adjoint(A::MatrixOperator) = has_adjoint(A.A)
44-
update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); L)
48+
update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); nothing)
4549

4650
isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC
4751
Base.iszero(L::MatrixOperator) = iszero(L.A)
@@ -107,7 +111,6 @@ Like MatrixOperator, but stores a Factorization instead.
107111
108112
Supports left division and `ldiv!` when applied to an array.
109113
"""
110-
# diagonal, bidiagonal, adjoint(factorization)
111114
struct InvertibleOperator{T,FType} <: AbstractSciMLLinearOperator{T}
112115
F::FType
113116

@@ -149,7 +152,9 @@ end
149152

150153
# traits
151154
Base.size(L::InvertibleOperator) = size(L.F)
155+
Base.transpose(L::InvertibleOperator) = InvertibleOperator(transpose(L.F))
152156
Base.adjoint(L::InvertibleOperator) = InvertibleOperator(L.F')
157+
Base.conj(L::InvertibleOperator) = InvertibleOperator(conj(L.F))
153158
LinearAlgebra.opnorm(L::InvertibleOperator{T}, p=2) where{T} = one(T) / opnorm(L.F)
154159
LinearAlgebra.issuccess(L::InvertibleOperator) = issuccess(L.F)
155160

@@ -179,8 +184,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr
179184
LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u)
180185

181186
"""
182-
L = AffineOperator(A, B, b)
187+
L = AffineOperator(A, B, b[; update_func])
183188
L(u) = A*u + B*b
189+
190+
Represents a time-dependent affine operator. The update function is called
191+
by `update_coefficients!` and is assumed to have the following signature:
192+
193+
update_func(b::AbstractArray,u,p,t) -> [modifies b]
184194
"""
185195
struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
186196
A::AType
@@ -238,7 +248,7 @@ end
238248
getops(L::AffineOperator) = (L.A, L.B, L.b)
239249
Base.size(L::AffineOperator) = size(L.A)
240250

241-
update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); L)
251+
update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); nothing)
242252

243253
islinear(::AffineOperator) = false
244254
Base.iszero(L::AffineOperator) = all(iszero, getops(L))

0 commit comments

Comments
 (0)