Skip to content

Commit 855f257

Browse files
Merge pull request #82 from vpuri3/batch
Batched Diagonal Operator
2 parents c5dfb94 + 85252ea commit 855f257

File tree

7 files changed

+484
-43
lines changed

7 files changed

+484
-43
lines changed

docs/src/interface.md

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
11
# The AbstractSciMLOperator Interface
22

3-
## Formal Properties of DiffEqOperators
3+
## Formal Properties of SciMLOperators
44

55
These are the formal properties that an `AbstractSciMLOperator` should obey
66
for it to work in the solvers.
77

8-
## AbstractDiffEqOperator Interface Description
9-
10-
1. Function call and multiplication: `L(du,u,p,t)` for inplace and `du = L(u,p,t)` for
11-
out-of-place, meaning `L*u` and `mul!`.
12-
2. If the operator is not a constant, update it with `(u,p,t)`. A mutating form, i.e.
13-
`update_coefficients!(A,u,p,t)` that changes the internal coefficients, and a
14-
out-of-place form `B = update_coefficients(A,u,p,t)`.
15-
3. `isconstant(A)` trait for whether the operator is constant or not.
16-
17-
## AbstractDiffEqLinearOperator Interface Description
8+
1. An `AbstractSciMLOperator` represents a linear or nonlinear operator with input/output
9+
being `AbstractArray`s. Specifically, a SciMLOperator, `L`, of size `(M,N)` accepts
10+
input argument `u` with leading length `N`, i.e. `size(u, 1) == N`, and returns an
11+
`AbstractArray` of the same dimension with leading length `M`, i.e. `size(L * u, 1) == M`.
12+
2. SciMLOperators can be applied to an `AbstractArray` via overloaded `Base.*`, or
13+
the in-place `LinearAlgebra.mul!`. Additionally, operators are allowed to be time,
14+
or parameter dependent. The state of a SciMLOperator can be updated by calling
15+
the mutating function `update_coefficients!(L, u, p, t)` where `p` representes
16+
parameters, and `t`, time. Calling a SciMLOperator as `L(du, u, p, t)` or out-of-place
17+
`L(u, p, t)` will automatically update the state of `L` before applying it to `u`.
18+
`L(u, p, t)` is the same operation as `L(u, p, t) * u`.
19+
3. To support the update functionality, we have lazily implemented a comprehensive operator
20+
algebra. That means a user can add, subtract, scale, compose and invert SciMLOperators,
21+
and the state of the resultant operator would be updated as expected upon calling
22+
`L(du, u, p, t)` or `L(u, p, t)` so long as an update function is provided for the
23+
component operators.
24+
25+
## AbstractSciMLOperator Interface Description
1826

1927
1. `AbstractSciMLLinearOperator <: AbstractSciMLOperator`
20-
2. Can absorb under multiplication by a scalar. In all algorithms things like
21-
`dt*L` show up all the time, so the linear operator must be able to absorb
22-
such constants.
23-
4. `isconstant(A)` trait for whether the operator is constant or not.
24-
5. Optional: `diagonal`, `symmetric`, etc traits from LinearMaps.jl.
25-
6. Optional: `exp(A)`. Required for simple exponential integration.
26-
7. Optional: `expv(A,u,t) = exp(t*A)*u` and `expv!(v,A::AbstractSciMLOperator,u,t)`
28+
2. `AbstractSciMLScalarOperator <: AbstractSciMLLinearOperator`
29+
3. `isconstant(A)` trait for whether the operator is constant or not.
30+
4. Optional: `exp(A)`. Required for simple exponential integration.
31+
5. Optional: `expv(A,u,t) = exp(t*A)*u` and `expv!(v,A::AbstractSciMLOperator,u,t)`
2732
Required for sparse-saving exponential integration.
28-
8. Optional: factorizations. `ldiv!`, `factorize` et. al. This is only required
33+
6. Optional: factorizations. `ldiv!`, `factorize` et. al. This is only required
2934
for algorithms which use the factorization of the operator (Crank-Nicolson),
3035
and only for when the default linear solve is used.
3136

@@ -49,4 +54,4 @@ That same trick then can be used pretty much anywhere you would've had a linear
4954
the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov methods work for A being
5055
affine as well, and all sorts of things. Thus affine operators have no matrix representation but they
5156
are still compatible with essentially any Krylov method which would otherwise be compatible with
52-
matrix-free representations, hence their support in the SciMLOperators interface.
57+
matrix-free representations, hence their support in the SciMLOperators interface.

src/SciMLOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ include("multidim.jl")
4141
include("scalar.jl")
4242
include("basic.jl")
4343
include("matrix.jl")
44+
include("batch.jl")
4445
include("func.jl")
4546
include("tensor.jl")
4647

src/batch.jl

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#
2+
"""
3+
BatchedDiagonalOperator(diag, [; update_func])
4+
5+
Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
6+
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
7+
by `update_coefficients!` and is assumed to have the following signature:
8+
9+
update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
10+
"""
11+
struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLLinearOperator{T}
12+
diag::D
13+
update_func::F
14+
15+
function BatchedDiagonalOperator(
16+
diag::AbstractArray;
17+
update_func=DEFAULT_UPDATE_FUNC
18+
)
19+
new{
20+
eltype(diag),
21+
typeof(diag),
22+
typeof(update_func)
23+
}(
24+
diag, update_func,
25+
)
26+
end
27+
end
28+
29+
function DiagonalOperator(u::AbstractArray; update_func=DEFAULT_UPDATE_FUNC)
30+
BatchedDiagonalOperator(u; update_func=update_func)
31+
end
32+
33+
# traits
34+
Base.size(L::BatchedDiagonalOperator) = (N = size(L.diag, 1); (N, N))
35+
Base.iszero(L::BatchedDiagonalOperator) = iszero(L.diag)
36+
Base.transpose(L::BatchedDiagonalOperator) = L
37+
Base.adjoint(L::BatchedDiagonalOperator) = conj(L)
38+
function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
39+
diag = conj(L.diag)
40+
update_func = if isreal(L)
41+
L.update_func
42+
else
43+
(L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t))
44+
end
45+
BatchedDiagonalOperator(diag; update_func=update_func)
46+
end
47+
48+
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
49+
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
50+
if isreal(L)
51+
true
52+
else
53+
d = _vec(L.diag)
54+
D = Diagonal(d)
55+
ishermitian(d)
56+
end
57+
end
58+
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(_vec(L.diag)))
59+
60+
isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC
61+
issquare(L::BatchedDiagonalOperator) = true
62+
has_adjoint(L::BatchedDiagonalOperator) = true
63+
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
64+
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)
65+
66+
getops(L::BatchedDiagonalOperator) = (L.diag,)
67+
68+
update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func(L.diag,u,p,t); nothing)
69+
70+
# operator application
71+
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
72+
Base.:\(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .\ u
73+
74+
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::BatchedDiagonalOperator, u::AbstractVecOrMat)
75+
V = _vec(v)
76+
U = _vec(u)
77+
d = _vec(L.diag)
78+
D = Diagonal(d)
79+
mul!(V, D, U)
80+
81+
v
82+
end
83+
84+
function LinearAlgebra.mul!(v::AbstractVecOrMat, L::BatchedDiagonalOperator, u::AbstractVecOrMat, α, β)
85+
V = _vec(v)
86+
U = _vec(u)
87+
d = _vec(L.diag)
88+
D = Diagonal(d)
89+
mul!(V, D, U, α, β)
90+
91+
v
92+
end
93+
94+
function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::BatchedDiagonalOperator, u::AbstractVecOrMat)
95+
V = _vec(v)
96+
U = _vec(u)
97+
d = _vec(L.diag)
98+
D = Diagonal(d)
99+
ldiv!(V, D, U)
100+
101+
v
102+
end
103+
104+
function LinearAlgebra.ldiv!(L::BatchedDiagonalOperator, u::AbstractVecOrMat)
105+
U = _vec(u)
106+
d = _vec(L.diag)
107+
D = Diagonal(d)
108+
ldiv!(D, U)
109+
110+
u
111+
end
112+
#

src/interface.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,24 +89,8 @@ has_mul!(L::AbstractSciMLOperator) = false # mul!(du, L, u)
8989
has_ldiv(L::AbstractSciMLOperator) = false # du = L\u
9090
has_ldiv!(L::AbstractSciMLOperator) = false # ldiv!(du, L, u)
9191

92-
### AbstractSciMLLinearOperator Interface
93-
94-
#=
95-
1. AbstractSciMLLinearOperator <: AbstractSciMLOperator
96-
2. Can absorb under multiplication by a scalar. In all algorithms things like
97-
dt*L show up all the time, so the linear operator must be able to absorb
98-
such constants.
99-
4. isconstant(A) trait for whether the operator is constant or not.
100-
5. Optional: diagonal, symmetric, etc traits from LinearMaps.jl.
101-
6. Optional: exp(A). Required for simple exponential integration.
102-
7. Optional: expmv(A,u,p,t) = exp(t*A)*u and expmv!(v,A::SciMLOperator,u,p,t)
103-
Required for sparse-saving exponential integration.
104-
8. Optional: factorizations. A_ldiv_B, factorize et. al. This is only required
105-
for algorithms which use the factorization of the operator (Crank-Nicholson),
106-
and only for when the default linear solve is used.
107-
=#
108-
109-
# Extra standard assumptions
92+
### Extra standard assumptions
93+
11094
isconstant(L) = true
11195
isconstant(L::AbstractSciMLOperator) = all(isconstant, getops(L))
11296
#isconstant(L::AbstractSciMLOperator) = L.update_func = DEFAULT_UPDATE_FUNC

src/matrix.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#
12
"""
23
MatrixOperator(A[; update_func])
34
@@ -32,7 +33,7 @@ for op in (
3233
:adjoint,
3334
:transpose,
3435
)
35-
@eval function Base.$op(L::MatrixOperator)
36+
@eval function Base.$op(L::MatrixOperator) # TODO - test this thoroughly
3637
MatrixOperator(
3738
$op(L.A);
3839
update_func= (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t)) # TODO - test
@@ -82,13 +83,30 @@ LinearAlgebra.mul!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat,
8283
LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(v, L.A, u)
8384
LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u)
8485

85-
""" Diagonal Operator """
86-
function DiagonalOperator(u::AbstractVector; update_func=DEFAULT_UPDATE_FUNC)
86+
"""
87+
DiagonalOperator(diag, [; update_func])
88+
89+
Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
90+
The update function is called by `update_coefficients!` and is assumed to have
91+
the following signature:
92+
93+
update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
94+
95+
When `diag` is an `AbstractVector` of length N, `L=DiagonalOpeator(diag, ...)`
96+
can be applied to `AbstractArray`s with `size(u, 1) == N`. Each column of the `u`
97+
will be scaled by `diag`, as in `LinearAlgebra.Diagonal(diag) * u`.
98+
99+
When `diag` is a multidimensional array, `L = DiagonalOperator(diag, ...)` forms
100+
an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of `diag`.
101+
`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
102+
with leading length `size(u, 1) = N`.
103+
"""
104+
function DiagonalOperator(diag::AbstractVector; update_func=DEFAULT_UPDATE_FUNC)
87105
function diag_update_func(A, u, p, t)
88106
update_func(A.diag, u, p, t)
89107
A
90108
end
91-
MatrixOperator(Diagonal(u); update_func=diag_update_func)
109+
MatrixOperator(Diagonal(diag); update_func=diag_update_func)
92110
end
93111
LinearAlgebra.Diagonal(L::MatrixOperator) = MatrixOperator(Diagonal(L.A))
94112

@@ -205,7 +223,7 @@ end
205223

206224
function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator},
207225
B::Union{AbstractMatrix,AbstractSciMLOperator},
208-
b::AbstractVecOrMat;
226+
b::AbstractArray;
209227
update_func=DEFAULT_UPDATE_FUNC,
210228
)
211229
@assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,8 @@ function _mat_sizes(L::AbstractSciMLOperator, u::AbstractArray)
2121

2222
size_in, size_out
2323
end
24+
25+
dims(A) = length(size(A))
26+
dims(::AbstractArray{<:Any,N}) where{N} = N
27+
dims(::AbstractSciMLOperator) = 2
2428
#

0 commit comments

Comments
 (0)