@@ -3,26 +3,29 @@ using FFTW
3
3
using Tullio
4
4
5
5
export
6
- SpectralConv1d
6
+ SpectralConv1d,
7
+ FNO
7
8
8
9
struct SpectralConv1d{T,S}
9
10
weight:: T
10
11
in_channel:: S
11
12
out_channel:: S
12
13
modes:: S
14
+ σ:: Function
13
15
end
14
16
15
17
function SpectralConv1d (
16
18
ch:: Pair{<:Integer,<:Integer} ,
17
- modes:: Integer ;
19
+ modes:: Integer ,
20
+ σ:: Function = identity;
18
21
init= Flux. glorot_uniform,
19
22
T:: DataType = Float32
20
23
)
21
24
in_chs, out_chs = ch
22
25
scale = one (T) / (in_chs * out_chs)
23
26
weights = scale * init (out_chs, in_chs, modes)
24
27
25
- return SpectralConv1d (weights, in_chs, out_chs, modes)
28
+ return SpectralConv1d (weights, in_chs, out_chs, modes, σ )
26
29
end
27
30
28
31
Flux. @functor SpectralConv1d
@@ -39,19 +42,37 @@ function (m::SpectralConv1d)(𝐱::AbstractArray)
39
42
40
43
𝐱_out = irfft (𝐱_padded , size (𝐱, 1 ), 1 )
41
44
42
- return 𝐱_out
45
+ return m . σ .( 𝐱_out)
43
46
end
44
47
45
- # function FNO(modes::Integer, width::Integer)
46
- # return Chain(
47
- # PermutedDimsArray(Dense(2, width),(2,1,3)),
48
- # relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
49
- # relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
50
- # relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
51
- # PermutedDimsArray(relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)), (0, 2, 1)),
52
- # Dense(width, 128, relu),
53
- # Dense(128, 1)
54
- # )
55
- # end
48
+ function FourierBlock (
49
+ ch:: Pair{<:Integer,<:Integer} ,
50
+ modes:: Integer ,
51
+ σ:: Function = identity
52
+ )
53
+ return Chain (
54
+ Parallel (+ ,
55
+ Conv ((1 , ), ch),
56
+ SpectralConv1d (ch, modes)
57
+ ),
58
+ x -> σ .(x)
59
+ )
60
+ end
61
+
62
+ function FNO ()
63
+ modes = 16
64
+ ch = 64 => 64
65
+
66
+ return Chain (
67
+ Conv ((1 , ), 2 => 64 ),
68
+ FourierBlock (ch, modes, relu),
69
+ FourierBlock (ch, modes, relu),
70
+ FourierBlock (ch, modes, relu),
71
+ FourierBlock (ch, modes),
72
+ Conv ((1 , ), 64 => 128 , relu),
73
+ Conv ((1 , ), 128 => 1 ),
74
+ flatten
75
+ )
76
+ end
56
77
57
78
# loss(m::SpectralConv1d, x, x̂) = sum(abs2, x̂ .- m(x)) / len
0 commit comments