Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit d3b3483

Browse files
authored
Merge pull request #12 from foldfelis/doc
Doc
2 parents 3124660 + c4cac42 commit d3b3483

File tree

5 files changed

+109
-9
lines changed

5 files changed

+109
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1818
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1919

2020
[compat]
21-
CUDA = "3.3"
21+
CUDA = "3.4"
2222
CUDAKernels = "0.3"
2323
ChainRulesCore = "1.3"
2424
DataDeps = "0.7"

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,12 @@
1313
[ci link]: https://github.com/foldfelis/NeuralOperators.jl/actions/workflows/CI.yml
1414
[codecov badge]: https://codecov.io/gh/foldfelis/NeuralOperators.jl/branch/master/graph/badge.svg?token=JQH3MP1Y9R
1515
[codecov link]: https://codecov.io/gh/foldfelis/NeuralOperators.jl
16+
17+
[Neural Operator](https://github.com/zongyi-li/graph-pde) is a novel deep learning method to learned the mapping
18+
between infinite-dimensional spaces of functions introduced by [Zongyi Li](https://github.com/zongyi-li)et al.
19+
20+
In this project I temporarily provide the SpectralConv layer and the
21+
[Fourier Neural Operator](https://github.com/zongyi-li/fourier_neural_operator).
22+
For more information, please take a look at the
23+
[Fourier Neural Operator model](src/model.jl) and the [example](example/burgers.jl) of solving
24+
[Burgers' equation](https://www.wikiwand.com/en/Burgers%27_equation)

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
34
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"

src/fourier.jl

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,55 @@
11
export
22
SpectralConv,
3+
SpectralConvPerm,
34
FourierOperator
45

5-
struct SpectralConv{N, T, S, F}
6+
abstract type AbstractSpectralConv{N, T, S, F} end
7+
8+
struct SpectralConv{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
69
weight::T
710
in_channel::S
811
out_channel::S
912
modes::NTuple{N, S}
1013
σ::F
1114
end
1215

13-
struct SpectralConvPerm{N, T, S, F}
16+
struct SpectralConvPerm{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
1417
weight::T
1518
in_channel::S
1619
out_channel::S
1720
modes::NTuple{N, S}
1821
σ::F
1922
end
2023

24+
"""
25+
SpectralConv(
26+
ch, modes, σ=identity;
27+
init=c_glorot_uniform, permuted=false, T=ComplexF32
28+
)
29+
30+
## SpectralConv
31+
32+
* ``v(x)``: input
33+
* ``F``, ``F^{-1}``: Fourier transform, inverse fourier transform
34+
* ``L``: linear transform on the lower Fouier modes.
35+
36+
``v(x)`` -> ``F`` -> ``L`` -> ``F^{-1}``
37+
38+
## Example
39+
40+
```jldoctest
41+
julia> SpectralConv(2=>5, (16, ))
42+
SpectralConv(2 => 5, (16,), σ=identity)
43+
44+
julia> using Flux
45+
46+
julia> SpectralConv(2=>5, (16, ), relu)
47+
SpectralConv(2 => 5, (16,), σ=relu)
48+
49+
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
50+
SpectralConvPerm(2 => 5, (16,), σ=relu)
51+
```
52+
"""
2153
function SpectralConv(
2254
ch::Pair{S, S},
2355
modes::NTuple{N, S},
@@ -38,10 +70,14 @@ end
3870
Flux.@functor SpectralConv
3971
Flux.@functor SpectralConvPerm
4072

41-
Base.ndims(::SpectralConv{N}) where {N} = N
42-
Base.ndims(::SpectralConvPerm{N}) where {N} = N
73+
Base.ndims(::AbstractSpectralConv{N}) where {N} = N
74+
75+
function Base.show(io::IO, l::AbstractSpectralConv)
76+
T = (l isa SpectralConv) ? SpectralConv : SpectralConvPerm
77+
print(io, "$(string(T))($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)))")
78+
end
4379

44-
function spectral_conv(m, 𝐱)
80+
function spectral_conv(m::AbstractSpectralConv, 𝐱::AbstractArray)
4581
n_dims = ndims(𝐱)
4682

4783
𝐱_fft = fft(Zygote.hook(real, 𝐱), 1:ndims(m)) # [x, in_chs, batch]
@@ -54,33 +90,86 @@ function spectral_conv(m, 𝐱)
5490
return m.σ.(𝐱_ifft)
5591
end
5692

57-
function (m::SpectralConv)(𝐱::AbstractArray)
93+
function (m::SpectralConv)(𝐱)
5894
𝐱ᵀ = permutedims(𝐱, (ntuple(i->i+1, ndims(m))..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
5995
𝐱_out = spectral_conv(m, 𝐱ᵀ) # [x, out_chs, batch]
6096
𝐱_outᵀ = permutedims(𝐱_out, (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
6197

6298
return 𝐱_outᵀ
6399
end
64100

65-
function (m::SpectralConvPerm)(𝐱::AbstractArray)
101+
function (m::SpectralConvPerm)(𝐱)
66102
return spectral_conv(m, 𝐱) # [x, out_chs, batch]
67103
end
68104

69105
############
70106
# operator #
71107
############
72108

109+
"""
110+
FourierOperator(ch, modes, σ=identity; permuted=false)
111+
112+
## FourierOperator
113+
114+
* ``v(x)``: input
115+
* ``F``, ``F^{-1}``: Fourier transform, inverse fourier transform
116+
* ``L``: linear transform on the lower Fouier modes
117+
* ``D``: local linear transform
118+
119+
```
120+
┌ F -> L -> F¯¹ ┐
121+
v(x) -> ┤ ├ -> + -> σ
122+
└ D ┘
123+
```
124+
125+
## Example
126+
127+
```jldoctest
128+
julia> FourierOperator(2=>5, (16, ))
129+
Chain(
130+
Parallel(
131+
+,
132+
Dense(2, 5), # 15 parameters
133+
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
134+
),
135+
NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
136+
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
137+
138+
julia> using Flux
139+
140+
julia> FourierOperator(2=>5, (16, ), relu)
141+
Chain(
142+
Parallel(
143+
+,
144+
Dense(2, 5), # 15 parameters
145+
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
146+
),
147+
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
148+
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
149+
150+
julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
151+
Chain(
152+
Parallel(
153+
+,
154+
Conv((1,), 2 => 5), # 15 parameters
155+
SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
156+
),
157+
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
158+
) # Total: 3 arrays, 175 parameters, 1.871 KiB.
159+
```
160+
"""
73161
function FourierOperator(
74162
ch::Pair{S, S},
75163
modes::NTuple{N, S},
76164
σ=identity;
77165
permuted=false
78166
) where {S<:Integer, N}
79167
short_cut = permuted ? Conv(Tuple(ones(Int, length(modes))), ch) : Dense(ch.first, ch.second)
168+
activation_func(x) = σ.(x)
80169

81170
return Chain(
82171
Parallel(+, short_cut, SpectralConv(ch, modes, permuted=permuted)),
83-
x -> σ.(x)
172+
activation_func
84173
)
85174
end
86175

test/fourier.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
SpectralConv(ch, modes)
88
)
99
@test ndims(SpectralConv(ch, modes)) == 1
10+
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity)"
1011

1112
𝐱, _ = get_burgers_data(n=5)
1213
@test size(m(𝐱)) == (128, 1024, 5)

0 commit comments

Comments
 (0)