Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit fa998db

Browse files
authored
Merge pull request #14 from yuehhua/spectralconv
Refactor SpectralConv
2 parents ca71475 + b20ab6a commit fa998db

File tree

3 files changed

+55
-70
lines changed

3 files changed

+55
-70
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Currently, `FourierOperator` is provided in this work.
2424

2525
## Usage
2626

27-
```
27+
```julia
2828
function FourierNeuralOperator()
2929
modes = (16, )
3030
ch = 64 => 64
@@ -48,13 +48,13 @@ end
4848

4949
Or you can just call:
5050

51-
```
51+
```julia
5252
fno = FourierNeuralOperator()
5353
```
5454

5555
And then train as a Flux model.
5656

57-
```
57+
```julia
5858
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- fno(𝐱)) / size(𝐱)[end]
5959
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
6060
Flux.@epochs 50 Flux.train!(loss, params(m), data, opt)

src/fourier.jl

Lines changed: 40 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,8 @@
11
export
22
SpectralConv,
3-
SpectralConvPerm,
43
FourierOperator
54

6-
abstract type AbstractSpectralConv{N, T, S, F} end
7-
8-
struct SpectralConv{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
9-
weight::T
10-
in_channel::S
11-
out_channel::S
12-
modes::NTuple{N, S}
13-
σ::F
14-
end
15-
16-
struct SpectralConvPerm{N, T, S, F} <: AbstractSpectralConv{N, T, S, F}
5+
struct SpectralConv{P, N, T, S, F}
176
weight::T
187
in_channel::S
198
out_channel::S
@@ -33,21 +22,21 @@ end
3322
* `modes`: The Fourier modes to be preserved.
3423
* `σ`: Activation function.
3524
* `permuted`: Whether the dim is permuted. If `permuted=true`, layer accepts
36-
data in the order of `(..., ch, batch)`, otherwise the order is `(ch, ..., batch)`.
25+
data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch, batch)`.
3726
3827
## Example
3928
4029
```jldoctest
4130
julia> SpectralConv(2=>5, (16, ))
42-
SpectralConv(2 => 5, (16,), σ=identity)
31+
SpectralConv(2 => 5, (16,), σ=identity, permuted=false)
4332
4433
julia> using Flux
4534
4635
julia> SpectralConv(2=>5, (16, ), relu)
47-
SpectralConv(2 => 5, (16,), σ=relu)
36+
SpectralConv(2 => 5, (16,), σ=relu, permuted=false)
4837
4938
julia> SpectralConv(2=>5, (16, ), relu, permuted=true)
50-
SpectralConvPerm(2 => 5, (16,), σ=relu)
39+
SpectralConv(2 => 5, (16,), σ=relu, permuted=true)
5140
```
5241
"""
5342
function SpectralConv(
@@ -61,23 +50,23 @@ function SpectralConv(
6150
in_chs, out_chs = ch
6251
scale = one(T) / (in_chs * out_chs)
6352
weights = scale * init(out_chs, in_chs, prod(modes))
53+
W = typeof(weights)
54+
F = typeof(σ)
6455

65-
L = permuted ? SpectralConvPerm : SpectralConv
66-
67-
return L(weights, in_chs, out_chs, modes, σ)
56+
return SpectralConv{permuted,N,W,S,F}(weights, in_chs, out_chs, modes, σ)
6857
end
6958

7059
Flux.@functor SpectralConv
71-
Flux.@functor SpectralConvPerm
7260

73-
Base.ndims(::AbstractSpectralConv{N}) where {N} = N
61+
Base.ndims(::SpectralConv{P,N}) where {P,N} = N
62+
63+
permuted(::SpectralConv{P}) where {P} = P
7464

75-
function Base.show(io::IO, l::AbstractSpectralConv)
76-
T = (l isa SpectralConv) ? SpectralConv : SpectralConvPerm
77-
print(io, "$(string(T))($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)))")
65+
function Base.show(io::IO, l::SpectralConv{P}) where {P}
66+
print(io, "SpectralConv($(l.in_channel) => $(l.out_channel), $(l.modes), σ=$(string(l.σ)), permuted=$P)")
7867
end
7968

80-
function spectral_conv(m::AbstractSpectralConv, 𝐱::AbstractArray)
69+
function spectral_conv(m::SpectralConv, 𝐱::AbstractArray)
8170
n_dims = ndims(𝐱)
8271

8372
𝐱_fft = fft(Zygote.hook(real, 𝐱), 1:ndims(m)) # [x, in_chs, batch]
@@ -90,22 +79,28 @@ function spectral_conv(m::AbstractSpectralConv, 𝐱::AbstractArray)
9079
return m.σ.(𝐱_ifft)
9180
end
9281

93-
function (m::SpectralConv)(𝐱)
82+
function (m::SpectralConv{false})(𝐱)
9483
𝐱ᵀ = permutedims(𝐱, (ntuple(i->i+1, ndims(m))..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
9584
𝐱_out = spectral_conv(m, 𝐱ᵀ) # [x, out_chs, batch]
9685
𝐱_outᵀ = permutedims(𝐱_out, (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
9786

9887
return 𝐱_outᵀ
9988
end
10089

101-
function (m::SpectralConvPerm)(𝐱)
90+
function (m::SpectralConv{true})(𝐱)
10291
return spectral_conv(m, 𝐱) # [x, out_chs, batch]
10392
end
10493

10594
############
10695
# operator #
10796
############
10897

98+
struct FourierOperator{L, C, F}
99+
linear::L
100+
conv::C
101+
σ::F
102+
end
103+
109104
"""
110105
FourierOperator(ch, modes, σ=identity; permuted=false)
111106
@@ -115,42 +110,21 @@ end
115110
* `modes`: The Fourier modes to be preserved for spectral convolution.
116111
* `σ`: Activation function.
117112
* `permuted`: Whether the dim is permuted. If `permuted=true`, layer accepts
118-
data in the order of `(..., ch, batch)`, otherwise the order is `(ch, ..., batch)`.
113+
data in the order of `(ch, ..., batch)`, otherwise the order is `(..., ch, batch)`.
119114
120115
## Example
121116
122117
```jldoctest
123118
julia> FourierOperator(2=>5, (16, ))
124-
Chain(
125-
Parallel(
126-
+,
127-
Dense(2, 5), # 15 parameters
128-
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
129-
),
130-
NeuralOperators.var"#activation_func#14"{typeof(identity)}(identity),
131-
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
119+
FourierOperator(2 => 5, (16,), σ=identity, permuted=false)
132120
133121
julia> using Flux
134122
135123
julia> FourierOperator(2=>5, (16, ), relu)
136-
Chain(
137-
Parallel(
138-
+,
139-
Dense(2, 5), # 15 parameters
140-
SpectralConv(2 => 5, (16,), σ=identity), # 160 parameters
141-
),
142-
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
143-
) # Total: 3 arrays, 175 parameters, 1.668 KiB.
124+
FourierOperator(2 => 5, (16,), σ=relu, permuted=false)
144125
145126
julia> FourierOperator(2=>5, (16, ), relu, permuted=true)
146-
Chain(
147-
Parallel(
148-
+,
149-
Conv((1,), 2 => 5), # 15 parameters
150-
SpectralConvPerm(2 => 5, (16,), σ=identity), # 160 parameters
151-
),
152-
NeuralOperators.var"#activation_func#14"{typeof(relu)}(NNlib.relu),
153-
) # Total: 3 arrays, 175 parameters, 1.871 KiB.
127+
FourierOperator(2 => 5, (16,), σ=relu, permuted=true)
154128
```
155129
"""
156130
function FourierOperator(
@@ -159,15 +133,23 @@ function FourierOperator(
159133
σ=identity;
160134
permuted=false
161135
) where {S<:Integer, N}
162-
short_cut = permuted ? Conv(Tuple(ones(Int, length(modes))), ch) : Dense(ch.first, ch.second)
163-
activation_func(x) = σ.(x)
136+
linear = permuted ? Conv(Tuple(ones(Int, length(modes))), ch) : Dense(ch.first, ch.second)
137+
conv = SpectralConv(ch, modes; permuted=permuted)
164138

165-
return Chain(
166-
Parallel(+, short_cut, SpectralConv(ch, modes, permuted=permuted)),
167-
activation_func
168-
)
139+
return FourierOperator(linear, conv, σ)
140+
end
141+
142+
Flux.@functor FourierOperator
143+
144+
function Base.show(io::IO, l::FourierOperator)
145+
print(io, "FourierOperator($(l.conv.in_channel) => $(l.conv.out_channel), $(l.conv.modes), σ=$(string(l.σ)), permuted=$(permuted(l.conv)))")
169146
end
170147

148+
function (m::FourierOperator)(𝐱)
149+
return m.σ.(m.linear(𝐱) + m.conv(𝐱))
150+
end
151+
152+
171153
#########
172154
# utils #
173155
#########

test/fourier.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "SpectralConv1d" begin
1+
@testset "1D SpectralConv" begin
22
modes = (16, )
33
ch = 64 => 128
44

@@ -7,7 +7,7 @@
77
SpectralConv(ch, modes)
88
)
99
@test ndims(SpectralConv(ch, modes)) == 1
10-
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity)"
10+
@test repr(SpectralConv(ch, modes)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=false)"
1111

1212
𝐱, _ = get_burgers_data(n=5)
1313
@test size(m(𝐱)) == (128, 1024, 5)
@@ -17,7 +17,7 @@
1717
Flux.train!(loss, params(m), data, Flux.ADAM())
1818
end
1919

20-
@testset "SpectralConvPerm1d" begin
20+
@testset "permuted 1D SpectralConv" begin
2121
modes = (16, )
2222
ch = 64 => 128
2323

@@ -26,6 +26,7 @@ end
2626
SpectralConv(ch, modes, permuted=true)
2727
)
2828
@test ndims(SpectralConv(ch, modes, permuted=true)) == 1
29+
@test repr(SpectralConv(ch, modes, permuted=true)) == "SpectralConv(64 => 128, (16,), σ=identity, permuted=true)"
2930

3031
𝐱, _ = get_burgers_data(n=5)
3132
𝐱 = permutedims(𝐱, (2, 1, 3))
@@ -36,14 +37,15 @@ end
3637
Flux.train!(loss, params(m), data, Flux.ADAM())
3738
end
3839

39-
@testset "FourierOperator1d" begin
40+
@testset "1D FourierOperator" begin
4041
modes = (16, )
4142
ch = 64 => 128
4243

4344
m = Chain(
4445
Dense(2, 64),
4546
FourierOperator(ch, modes)
4647
)
48+
@test repr(FourierOperator(ch, modes)) == "FourierOperator(64 => 128, (16,), σ=identity, permuted=false)"
4749

4850
𝐱, _ = get_burgers_data(n=5)
4951
@test size(m(𝐱)) == (128, 1024, 5)
@@ -53,14 +55,15 @@ end
5355
Flux.train!(loss, params(m), data, Flux.ADAM())
5456
end
5557

56-
@testset "FourierOperatorPerm1d" begin
58+
@testset "permuted 1D FourierOperator" begin
5759
modes = (16, )
5860
ch = 64 => 128
5961

6062
m = Chain(
6163
Conv((1, ), 2=>64),
6264
FourierOperator(ch, modes, permuted=true)
6365
)
66+
@test repr(FourierOperator(ch, modes, permuted=true)) == "FourierOperator(64 => 128, (16,), σ=identity, permuted=true)"
6467

6568
𝐱, _ = get_burgers_data(n=5)
6669
𝐱 = permutedims(𝐱, (2, 1, 3))
@@ -71,7 +74,7 @@ end
7174
Flux.train!(loss, params(m), data, Flux.ADAM())
7275
end
7376

74-
@testset "SpectralConv2d" begin
77+
@testset "2D SpectralConv" begin
7578
modes = (16, 16)
7679
ch = 64 => 64
7780

@@ -89,7 +92,7 @@ end
8992
Flux.train!(loss, params(m), data, Flux.ADAM())
9093
end
9194

92-
@testset "SpectralConvPerm2d" begin
95+
@testset "permuted 2D SpectralConv" begin
9396
modes = (16, 16)
9497
ch = 64 => 64
9598

@@ -108,7 +111,7 @@ end
108111
Flux.train!(loss, params(m), data, Flux.ADAM())
109112
end
110113

111-
@testset "FourierOperator2d" begin
114+
@testset "2D FourierOperator" begin
112115
modes = (16, 16)
113116
ch = 64 => 64
114117

@@ -125,7 +128,7 @@ end
125128
Flux.train!(loss, params(m), data, Flux.ADAM())
126129
end
127130

128-
@testset "FourierOperatorPerm2d" begin
131+
@testset "permuted 2D FourierOperator" begin
129132
modes = (16, 16)
130133
ch = 64 => 64
131134

0 commit comments

Comments
 (0)