Skip to content

Commit a01e08c

Browse files
Merge pull request #67 from vpuri3/func
Fix FunctionOperator caching
2 parents 7f70a6d + 1bfa2ce commit a01e08c

File tree

14 files changed

+1165
-808
lines changed

14 files changed

+1165
-808
lines changed

benchmarks/tensor.jl

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using SciMLOperators, LinearAlgebra, BenchmarkTools
2+
using SciMLOperators: IdentityOperator,
3+
4+
Id = IdentityOperator{12}()
5+
A = rand(12,12)
6+
B = rand(12,12)
7+
C = rand(12,12)
8+
9+
println("#===============================#")
10+
println("2D Tensor Products")
11+
println("#===============================#")
12+
13+
println("⊗(A, B)")
14+
15+
u = rand(12^2, 100)
16+
v = rand(12^2, 100)
17+
18+
T = (A, B)
19+
T = cache_operator(T, u)
20+
21+
@btime mul!($v, $T, $u)
22+
23+
println("⊗(I, B)")
24+
25+
u = rand(12^2, 100)
26+
v = rand(12^2, 100)
27+
28+
T = (Id, B)
29+
T = cache_operator(T, u)
30+
31+
@btime mul!($v, $T, $u)
32+
33+
println("⊗(A, I)")
34+
35+
u = rand(12^2, 100)
36+
v = rand(12^2, 100)
37+
38+
T = (A, Id)
39+
T = cache_operator(T, u)
40+
41+
@btime mul!($v, $T, $u)
42+
43+
println("#===============================#")
44+
println("3D Tensor Products")
45+
println("#===============================#")
46+
47+
println("⊗(⊗(A, B), C)")
48+
49+
u = rand(12^3, 100)
50+
v = rand(12^3, 100)
51+
52+
T = ((A, B), C)
53+
T = cache_operator(T, u)
54+
55+
mul!(v, T, u) # dunny
56+
@btime mul!($v, $T, $u); #
57+
58+
println("⊗(A, ⊗(B, C))")
59+
60+
u = rand(12^3, 100)
61+
v = rand(12^3, 100)
62+
63+
T = (A, (B, C))
64+
T = cache_operator(T, u)
65+
66+
mul!(v, T, u) # dunny
67+
@btime mul!($v, $T, $u); #
68+
69+
println("#===============================#")
70+
nothing

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ include("pages.jl")
44

55
makedocs(
66
sitename="SciMLOperators.jl",
7-
authors="Chris Rackauckas, Alex Jones",
7+
authors="Chris Rackauckas, Alex Jones, Vedant Puri",
88
modules=[SciMLOperators],
99
clean=true,doctest=false,
1010
format = Documenter.HTML(analytics = "UA-90474609-3",

docs/pages.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
pages = [
2-
"Home" => "index.md",
3-
"interface.md",
4-
"premade_operators.md"
5-
]
2+
"Home" => "index.md",
3+
"interface.md",
4+
"premade_operators.md",
5+
"tutorials/fftw.md"
6+
]

docs/src/tutorials/fftw.md

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Wrap a Fourier transform with SciMLOperators
2+
3+
In this tutorial, we will wrap a Fast Fourier Transform (FFT) in a SciMLOperator via the
4+
`FunctionOperator` interface. FFTs are commonly used algorithms for performing numerical
5+
interpolation and differentiation. In this example, we will use the FFT to compute the
6+
derivative of a function.
7+
8+
## Copy-Paste Code
9+
10+
```
11+
using SciMLOperators
12+
using LinearAlgebra, FFTW
13+
14+
L = 2π
15+
n = 256
16+
dx = L / n
17+
x = range(start=-L/2, stop=L/2-dx, length=n) |> Array
18+
19+
u = @. sin(5x)cos(7x);
20+
du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x);
21+
22+
transform = plan_rfft(x)
23+
k = Array(rfftfreq(n, 2π*n/L))
24+
25+
op_transform = FunctionOperator(
26+
(du,u,p,t) -> mul!(du, transform, u);
27+
isinplace=true,
28+
T=ComplexF64,
29+
size=(length(k),n),
30+
31+
input_prototype=x,
32+
output_prototype=im*k,
33+
34+
op_inverse = (du,u,p,t) -> ldiv!(du, transform, u)
35+
)
36+
37+
ik = im * DiagonalOperator(k)
38+
Dx = op_transform \ ik * op_transform
39+
40+
Dx = cache_operator(Dx, x)
41+
42+
@show ≈(Dx * u, du; atol=1e-8)
43+
@show ≈(mul!(copy(u), Dx, u), du; atol=1e-8)
44+
```
45+
46+
## Explanation
47+
48+
We load `SciMLOperators`, `LinearAlgebra`, and `FFTW` (short for Fastest Fourier Transform
49+
in the West), a common Fourier transform library. Next, we define an equispaced grid from
50+
-π to π, and write the function `u` that we intend to differentiate. Since this is a
51+
trivial example, we already know the derivative, `du` and write it down to later test our
52+
FFT wrapper.
53+
54+
```
55+
using SciMLOperators
56+
using LinearAlgebra, FFTW
57+
58+
L = 2π
59+
n = 256
60+
dx = L / n
61+
x = range(start=-L/2, stop=L/2-dx, length=n) |> Array
62+
63+
u = @. sin(5x)cos(7x);
64+
du = @. 5cos(5x)cos(7x) - 7sin(5x)sin(7x);
65+
66+
```
67+
68+
Now, we define the Fourier transform. Since our input is purely Real, we use the real
69+
Fast Fourier Transform. The funciton `plan_rfft` outputs a real fast fourier transform
70+
object that can be applied to inputs that are like `x` as follows: `xhat = transform * x`,
71+
and `LinearAlgebra.mul!(xhat, transform, x)`. We also get `k`, the frequency modes sampled by
72+
our finite grid, via the function `rfftfreq`.
73+
74+
```
75+
transform = plan_rfft(x)
76+
k = Array(rfftfreq(n, 2π*n/L))
77+
```
78+
79+
Now we are ready to define our wrapper for the FFT object. To `FunctionOperator`, we
80+
pass the in-place forward application of the transform,
81+
`(du,u,p,t) -> mul!(du, transform, u)`, its inverse application,
82+
`(du,u,p,t) -> ldiv!(du, transform, u)`, as well as input and output prototype vectors.
83+
We also set the flag `isinplace` to `true` to signal that we intend to use the operator
84+
in a non-allocating way, and pass in the element-type and size of the operator.
85+
86+
```
87+
op_transform = FunctionOperator(
88+
(du,u,p,t) -> mul!(du, transform, u);
89+
isinplace=true,
90+
T=ComplexF64,
91+
size=(length(k),n),
92+
93+
input_prototype=x,
94+
output_prototype=im*k,
95+
96+
op_inverse = (du,u,p,t) -> ldiv!(du, transform, u)
97+
)
98+
```
99+
100+
After wrapping the FFT with `FunctionOperator`, we are ready to compose it with other
101+
SciMLOperators. Below we form the derivative operator, and cache it via the function
102+
`cache_operator` that requires an input prototype. We can test our derivative operator
103+
both in-place, and out-of-place by comparing its output to the analytical derivative.
104+
105+
```
106+
ik = im * DiagonalOperator(k)
107+
Dx = op_transform \ ik * op_transform
108+
109+
@show ≈(Dx * u, du; atol=1e-8)
110+
@show ≈(mul!(copy(u), Dx, u), du; atol=1e-8)
111+
```
112+
113+
```
114+
≈(Dx * u, du; atol = 1.0e-8) = true
115+
≈(mul!(copy(u), Dx, u), du; atol = 1.0e-8) = true
116+
```
117+
118+

src/SciMLOperators.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,15 @@ $(TYPEDEF)
2828
"""
2929
abstract type AbstractSciMLLinearOperator{T} <: AbstractSciMLOperator{T} end
3030

31-
include("interface.jl")
3231
include("utils.jl")
32+
include("interface.jl")
3333
include("left.jl")
3434
include("multidim.jl")
35+
3536
include("basic.jl")
36-
include("sciml.jl")
37+
include("matrix.jl")
38+
include("func.jl")
39+
include("tensor.jl")
3740

3841
export ScalarOperator,
3942
MatrixOperator,

0 commit comments

Comments
 (0)