Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 4c07880

Browse files
committed
add doc for FourierOperator
1 parent 16a76db commit 4c07880

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
34
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"

src/fourier.jl

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ end
4141
julia> SpectralConv(2=>5, (16, ))
4242
SpectralConv(2 => 5, (16,), σ=identity)
4343
44+
julia> using Flux
45+
4446
julia> SpectralConv(2=>5, (16, ), relu)
4547
SpectralConv(2 => 5, (16,), σ=relu)
4648
4749
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
4850
SpectralConvPerm(2 => 5, (16,), σ=relu)
49-
````
51+
```
5052
"""
5153
function SpectralConv(
5254
ch::Pair{S, S},
@@ -104,17 +106,70 @@ end
104106
# operator #
105107
############
106108

109+
"""
110+
FourierOperator(ch, modes, σ=identity; permuted=false)
111+
112+
## FourierOperator
113+
114+
* ``v(x)``: input
115+
* ``F``, ``F^{-1}``: Fourier transform, inverse fourier transform
116+
* ``L``: linear transform on the lower Fouier modes
117+
* ``D``: local linear transform
118+
119+
```
120+
┌ F -> L -> F¯¹ ┐
121+
v(x) -> ┤ ├ -> + -> σ
122+
└ D ┘
123+
```
124+
125+
## Example
126+
127+
```jldoctest
128+
julia> FourierOperator(2=>5, (16, ))
129+
Chain(
130+
Parallel(
131+
+,
132+
Dense(2, 5), # 15 parameters
133+
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
134+
),
135+
NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
136+
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
137+
138+
julia> using Flux
139+
140+
julia> FourierOperator(2=>5, (16, ), relu)
141+
Chain(
142+
Parallel(
143+
+,
144+
Dense(2, 5), # 15 parameters
145+
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
146+
),
147+
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
148+
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
149+
150+
julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
151+
Chain(
152+
Parallel(
153+
+,
154+
Conv((1,), 2 => 5), # 15 parameters
155+
SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
156+
),
157+
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
158+
) # Total: 3 arrays, 175 parameters, 1.871 KiB.
159+
```
160+
"""
107161
function FourierOperator(
108162
ch::Pair{S, S},
109163
modes::NTuple{N, S},
110164
σ=identity;
111165
permuted=false
112166
) where {S<:Integer, N}
113167
short_cut = permuted ? Conv(Tuple(ones(Int, length(modes))), ch) : Dense(ch.first, ch.second)
168+
activation_func(x) = σ.(x)
114169

115170
return Chain(
116171
Parallel(+, short_cut, SpectralConv(ch, modes, permuted=permuted)),
117-
x -> σ.(x)
172+
activation_func
118173
)
119174
end
120175

test/fourier.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
SpectralConv(ch, modes)
88
)
99
@test ndims(SpectralConv(ch, modes)) == 1
10+
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity)"
1011

1112
𝐱, _ = get_burgers_data(n=5)
1213
@test size(m(𝐱)) == (128, 1024, 5)

0 commit comments

Comments
 (0)