Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 3f860ba

Browse files
committed
Implement FNO
1 parent f367be8 commit 3f860ba

File tree

2 files changed

+45
-19
lines changed

2 files changed

+45
-19
lines changed

src/fourier.jl

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,29 @@ using FFTW
33
using Tullio
44

55
export
6-
SpectralConv1d
6+
SpectralConv1d,
7+
FNO
78

89
struct SpectralConv1d{T,S}
910
weight::T
1011
in_channel::S
1112
out_channel::S
1213
modes::S
14+
σ::Function
1315
end
1416

1517
function SpectralConv1d(
1618
ch::Pair{<:Integer,<:Integer},
17-
modes::Integer;
19+
modes::Integer,
20+
σ::Function=identity;
1821
init=Flux.glorot_uniform,
1922
T::DataType=Float32
2023
)
2124
in_chs, out_chs = ch
2225
scale = one(T) / (in_chs * out_chs)
2326
weights = scale * init(out_chs, in_chs, modes)
2427

25-
return SpectralConv1d(weights, in_chs, out_chs, modes)
28+
return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
2629
end
2730

2831
Flux.@functor SpectralConv1d
@@ -39,19 +42,37 @@ function (m::SpectralConv1d)(𝐱::AbstractArray)
3942

4043
𝐱_out = irfft(𝐱_padded , size(𝐱, 1), 1)
4144

42-
return 𝐱_out
45+
return m.σ.(𝐱_out)
4346
end
4447

45-
# function FNO(modes::Integer, width::Integer)
46-
# return Chain(
47-
# PermutedDimsArray(Dense(2, width),(2,1,3)),
48-
# relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
49-
# relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
50-
# relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
51-
# PermutedDimsArray(relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)), (0, 2, 1)),
52-
# Dense(width, 128, relu),
53-
# Dense(128, 1)
54-
# )
55-
# end
48+
function FourierBlock(
49+
ch::Pair{<:Integer,<:Integer},
50+
modes::Integer,
51+
σ::Function=identity
52+
)
53+
return Chain(
54+
Parallel(+,
55+
Conv((1, ), ch),
56+
SpectralConv1d(ch, modes)
57+
),
58+
x -> σ.(x)
59+
)
60+
end
61+
62+
function FNO()
63+
modes = 16
64+
ch = 64 => 64
65+
66+
return Chain(
67+
Conv((1, ), 2=>64),
68+
FourierBlock(ch, modes, relu),
69+
FourierBlock(ch, modes, relu),
70+
FourierBlock(ch, modes, relu),
71+
FourierBlock(ch, modes),
72+
Conv((1, ), 64=>128, relu),
73+
Conv((1, ), 128=>1),
74+
flatten
75+
)
76+
end
5677

5778
# loss(m::SpectralConv1d, x, x̂) = sum(abs2, x̂ .- m(x)) / len

test/fourier.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@ using Flux
22

33
@testset "fourier" begin
44
modes = 16
5-
width = 64
6-
ch = width => width
5+
ch = 64 => 64
6+
77
m = Chain(
8-
Conv((1, ), 2=>width),
8+
Conv((1, ), 2=>64),
99
SpectralConv1d(ch, modes)
1010
)
1111

1212
𝐱, _ = get_data()
13-
@show size(m(𝐱))
13+
@test size(m(𝐱)) == (1024, 64, 1000)
14+
end
15+
16+
@testset "FNO" begin
17+
𝐱, 𝐲 = get_data()
18+
@test size(FNO()(𝐱)) == size(𝐲)
1419
end

0 commit comments

Comments
 (0)