Skip to content

Commit b0e9d62

Browse files
Merge pull request #275 from SciML/operator_algebras
Fix operator algebras tutorial
2 parents 03320b8 + 0d2e37c commit b0e9d62

File tree

4 files changed

+92
-51
lines changed

4 files changed

+92
-51
lines changed

README.md

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,52 +41,61 @@ julia> Pkg.add("SciMLOperators")
4141
Let `M`, `D`, `F` be matrix-based, diagonal-matrix-based, and function-based
4242
`SciMLOperators` respectively.
4343

44-
```julia
44+
Let `M`, `D`, `F` be matrix-based, diagonal-matrix-based, and function-based
45+
`SciMLOperators` respectively.
46+
47+
```@example operator_algebra
48+
using SciMLOperators, LinearAlgebra
4549
N = 4
46-
f(u, p, t) = u .* u
47-
f(v, u, p, t) = v .= u .* u
50+
function f(v, u, p, t)
51+
u .* v
52+
end
53+
function f(w, v, u, p, t)
54+
w .= u .* v
55+
end
56+
57+
u = rand(4)
58+
p = nothing # parameter struct
59+
t = 0.0 # time
4860
4961
M = MatrixOperator(rand(N, N))
5062
D = DiagonalOperator(rand(N))
51-
F = FunctionOperator(f, zeros(N), zeros(N))
63+
F = FunctionOperator(f, zeros(N), zeros(N); u, p, t)
5264
```
5365

5466
Then, the following codes just work.
5567

56-
```julia
68+
```@example operator_algebra
5769
L1 = 2M + 3F + LinearAlgebra.I + rand(N, N)
5870
L2 = D * F * M'
5971
L3 = kron(M, D, F)
60-
L4 = M \ D
72+
L4 = lu(M) \ D
6173
L5 = [M; D]' * [M F; F D] * [F; D]
6274
```
6375

6476
Each `L#` can be applied to `AbstractVector`s of appropriate sizes:
6577

66-
```julia
67-
p = nothing # parameter struct
68-
t = 0.0 # time
69-
70-
u = rand(N)
71-
v = L1(u, p, t) # == L1 * u
78+
```@example operator_algebra
79+
v = rand(N)
80+
w = L1(v, u, p, t) # == L1 * v
7281
73-
u_kron = rand(N^3)
74-
v_kron = L3(u_kron, p, t) # == L3 * u_kron
82+
v_kron = rand(N^3)
83+
w_kron = L3(v_kron, u, p, t) # == L3 * v_kron
7584
```
7685

77-
For mutating operator evaluations, call `cache_operator` to generate
78-
in-place cache so the operation is nonallocating.
86+
For mutating operator evaluations, call `cache_operator` to generate an
87+
in-place cache, so the operation is nonallocating.
7988

80-
```julia
89+
```@example operator_algebra
8190
α, β = rand(2)
8291
8392
# allocate cache
8493
L2 = cache_operator(L2, u)
8594
L4 = cache_operator(L4, u)
8695
8796
# allocation-free evaluation
88-
L2(v, u, p, t) # == mul!(v, L2, u)
89-
L4(v, u, p, t, α, β) # == mul!(v, L4, u, α, β)
97+
L2(w, v, u, p, t) # == mul!(w, L2, v)
98+
L4(w, v, u, p, t, α, β) # == mul!(w, L4, v, α, β)
9099
```
91100

92101
## Roadmap
Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,41 @@
11
## [Demonstration of Operator Algebras and Kron](@id operator_algebras)
22

33
Let `M`, `D`, `F` be matrix-based, diagonal-matrix-based, and function-based
4-
`SciMLOperators` respectively.
4+
`SciMLOperators` respectively. Here are some examples of composing operators
5+
in order to build more complex objects and using their operations.
56

6-
```julia
7+
```@example operator_algebra
8+
using SciMLOperators, LinearAlgebra
79
N = 4
8-
f = (v, u, p, t) -> u .* v
10+
function f(v, u, p, t)
11+
u .* v
12+
end
13+
function f(w, v, u, p, t)
14+
w .= u .* v
15+
end
16+
17+
u = rand(4)
18+
p = nothing # parameter struct
19+
t = 0.0 # time
920
1021
M = MatrixOperator(rand(N, N))
1122
D = DiagonalOperator(rand(N))
12-
F = FunctionOperator(f, zeros(N), zeros(N))
23+
F = FunctionOperator(f, zeros(N), zeros(N); u, p, t)
1324
```
1425

1526
Then, the following codes just work.
1627

17-
```julia
28+
```@example operator_algebra
1829
L1 = 2M + 3F + LinearAlgebra.I + rand(N, N)
1930
L2 = D * F * M'
2031
L3 = kron(M, D, F)
21-
L4 = M \ D
32+
L4 = lu(M) \ D
2233
L5 = [M; D]' * [M F; F D] * [F; D]
2334
```
2435

2536
Each `L#` can be applied to `AbstractVector`s of appropriate sizes:
2637

27-
```julia
28-
p = nothing # parameter struct
29-
t = 0.0 # time
30-
31-
u = rand(N)
38+
```@example operator_algebra
3239
v = rand(N)
3340
w = L1(v, u, p, t) # == L1 * v
3441
@@ -39,7 +46,7 @@ w_kron = L3(v_kron, u, p, t) # == L3 * v_kron
3946
For mutating operator evaluations, call `cache_operator` to generate an
4047
in-place cache, so the operation is nonallocating.
4148

42-
```julia
49+
```@example operator_algebra
4350
α, β = rand(2)
4451
4552
# allocate cache
@@ -49,15 +56,4 @@ L4 = cache_operator(L4, u)
4956
# allocation-free evaluation
5057
L2(w, v, u, p, t) # == mul!(w, L2, v)
5158
L4(w, v, u, p, t, α, β) # == mul!(w, L4, v, α, β)
52-
```
53-
54-
The calling signature `L(v, u, p, t)`, for out-of-place evaluations, is
55-
equivalent to `L * v`, and the in-place evaluation `L(w, v, u, p, t, args...)`
56-
is equivalent to `LinearAlgebra.mul!(w, L, v, args...)`, where the arguments
57-
`u, p, t` are passed to `L` to update its state. More details are provided
58-
in the operator update section below.
59-
60-
The `(v, u, p, t)` calling signature is standardized over the `SciML`
61-
ecosystem and is flexible enough to support use cases such as time-evolution
62-
in ODEs, as well as sensitivity computation with respect to the parameter
63-
object `p`.
59+
```

src/func.jl

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,17 @@ function Base.:*(L::FunctionOperator{iip, true}, v::AbstractArray) where {iip}
734734
vec_output ? vec(W) : W
735735
end
736736

737+
function Base.:*(L::FunctionOperator{iip, false}, v::AbstractArray) where {iip}
738+
_sizecheck(L, v, nothing)
739+
V, _, vec_output = _unvec(L, v, nothing)
740+
741+
W, _ = L.cache
742+
W = copy(W)
743+
L.op(W, V, L.u, L.p, L.t; L.traits.kwargs...)
744+
745+
vec_output ? vec(W) : W
746+
end
747+
737748
function Base.:\(L::FunctionOperator{iip, true}, v::AbstractArray) where {iip}
738749
_sizecheck(L, nothing, v)
739750
_, V, vec_output = _unvec(L, nothing, v)
@@ -743,16 +754,30 @@ function Base.:\(L::FunctionOperator{iip, true}, v::AbstractArray) where {iip}
743754
vec_output ? vec(W) : W
744755
end
745756

757+
function Base.:\(L::FunctionOperator{iip, false}, v::AbstractArray) where {iip}
758+
_sizecheck(L, nothing, v)
759+
_, V, vec_output = _unvec(L, nothing, v)
760+
761+
W, _ = L.cache
762+
W = copy(W)
763+
L.op_inverse(W, V, L.u, L.p, L.t; L.traits.kwargs...)
764+
765+
vec_output ? vec(W) : W
766+
end
767+
746768
function LinearAlgebra.mul!(w::AbstractArray, L::FunctionOperator{true}, v::AbstractArray)
747769
_sizecheck(L, v, w)
748770
V, W, vec_output = _unvec(L, v, w)
749771
L.op(W, V, L.u, L.p, L.t; L.traits.kwargs...)
750772
vec_output ? vec(W) : W
751773
end
752774

753-
function LinearAlgebra.mul!(::AbstractArray, L::FunctionOperator{false}, ::AbstractArray,
775+
function LinearAlgebra.mul!(w::AbstractArray, L::FunctionOperator{false}, ::AbstractArray,
754776
args...)
755-
@error "LinearAlgebra.mul! not defined for out-of-place operator $L"
777+
_sizecheck(L, v, w)
778+
V, W, vec_output = _unvec(L, v, w)
779+
W .= L.op(V, L.u, L.p, L.t; L.traits.kwargs...)
780+
vec_output ? vec(W) : W
756781
end
757782

758783
function LinearAlgebra.mul!(w::AbstractArray, L::FunctionOperator{true, oop, false},
@@ -797,15 +822,26 @@ function LinearAlgebra.ldiv!(L::FunctionOperator{true}, v::AbstractArray)
797822
copy!(W, V)
798823
L.op_inverse(W, V, L.u, L.p, L.t; L.traits.kwargs...) # ldiv!(U, L, V)
799824

800-
vec_output ? vec(W) : W
825+
V .= W
826+
vec_output ? vec(V) : V
801827
end
802828

803-
function LinearAlgebra.ldiv!(v::AbstractArray, L::FunctionOperator{false}, u::AbstractArray)
804-
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
829+
function LinearAlgebra.ldiv!(w::AbstractArray, L::FunctionOperator{false}, v::AbstractArray)
830+
_sizecheck(L, v, w)
831+
W, V, _ = _unvec(L, w, v)
832+
833+
W .= L.op_inverse(V, L.u, L.p, L.t; L.traits.kwargs...)
834+
835+
w
805836
end
806837

807-
function LinearAlgebra.ldiv!(L::FunctionOperator{false}, u::AbstractArray)
808-
@error "LinearAlgebra.ldiv! not defined for out-of-place $L"
838+
function LinearAlgebra.ldiv!(L::FunctionOperator{false}, v::AbstractArray)
839+
_sizecheck(L, nothing, v)
840+
V, _, vec_output = _unvec(L, v, nothing)
841+
842+
V .= L.op_inverse(V, L.u, L.p, L.t; L.traits.kwargs...) # ldiv!(W, L, V)
843+
844+
vec_output ? vec(V) : V
809845
end
810846

811847
# Out-of-place: v is action vector, u is update vector

src/tensor.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function Base.:*(L::TensorProductOperator, v::AbstractVecOrMat)
172172
k = size(v, 2)
173173

174174
U = reshape(v, (ni, no * k))
175-
C = inner * U
175+
C = stack([inner * _v for _v in eachcol(U)])
176176

177177
V = outer_mul(L, v, C)
178178

0 commit comments

Comments
 (0)