Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 6027fdd

Browse files
committed
to complex...
1 parent 9c4a5ac commit 6027fdd

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

src/fourier.jl

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ export
77
FourierOperator,
88
FNO
99

10-
struct SpectralConv1d{T,S}
10+
c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
11+
12+
struct SpectralConv1d{T, S}
1113
weight::T
1214
in_channel::S
1315
out_channel::S
@@ -19,8 +21,8 @@ function SpectralConv1d(
1921
ch::Pair{<:Integer,<:Integer},
2022
modes::Integer,
2123
σ=identity;
22-
init=Flux.glorot_uniform,
23-
T::DataType=Float32
24+
init=c_glorot_uniform,
25+
T::DataType=ComplexF32
2426
)
2527
in_chs, out_chs = ch
2628
scale = one(T) / (in_chs * out_chs)
@@ -32,16 +34,17 @@ end
3234
Flux.@functor SpectralConv1d
3335

3436
function (m::SpectralConv1d)(𝐱::AbstractArray)
35-
𝐱_fft = rfft(𝐱, 1) # [x, in_chs, batch]
36-
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]
37+
𝐱_fft = fft(𝐱, 2) # [in_chs, x, batch]
38+
𝐱_selected = 𝐱_fft[:, 1:m.modes, :] # [in_chs, modes, batch]
3739

38-
# [modes, out_chs, batch] <- [modes, in_chs, batch] [out_chs, in_chs, modes]
39-
@tullio 𝐱_weighted[m, o, b] := 𝐱_selected[m, i, b] * m.weight[o, i, m]
40+
# [out_chs, modes, batch] <- [in_chs, modes, batch] [out_chs, in_chs, modes]
41+
@tullio 𝐱_weighted[o, m, b] := 𝐱_selected[i, m, b] * m.weight[o, i, m]
4042

41-
d = size(𝐱, 1) ÷ 2 + 1 - m.modes
42-
𝐱_padded = cat(𝐱_weighted, zeros(Float32, d, size(𝐱)[2:end]...), dims=1)
43+
s = size(𝐱_weighted)
44+
d = size(𝐱, 2) - m.modes
45+
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, s[1], d, s[3:end]...), dims=2)
4346

44-
𝐱_out = irfft(𝐱_padded , size(𝐱, 1), 1)
47+
𝐱_out = ifft(𝐱_padded, 2)
4548

4649
return m.σ.(𝐱_out)
4750
end
@@ -53,7 +56,7 @@ function FourierOperator(
5356
)
5457
return Chain(
5558
Parallel(+,
56-
Conv((1, ), ch),
59+
Dense(ch.first, ch.second, init=c_glorot_uniform),
5760
SpectralConv1d(ch, modes)
5861
),
5962
x -> σ.(x)
@@ -63,15 +66,16 @@ end
6366
function FNO()
6467
modes = 16
6568
ch = 64 => 64
69+
σ = x -> @. log(1 + exp(x))
6670

6771
return Chain(
68-
Conv((1, ), 2=>64),
69-
FourierOperator(ch, modes, relu),
70-
FourierOperator(ch, modes, relu),
71-
FourierOperator(ch, modes, relu),
72+
Dense(2, 64, init=c_glorot_uniform),
73+
FourierOperator(ch, modes, σ),
74+
FourierOperator(ch, modes, σ),
75+
FourierOperator(ch, modes, σ),
7276
FourierOperator(ch, modes),
73-
Conv((1, ), 64=>128, relu),
74-
Conv((1, ), 128=>1),
77+
Dense(64, 128, σ, init=c_glorot_uniform),
78+
Dense(128, 1, init=c_glorot_uniform),
7579
flatten
7680
)
7781
end

test/fourier.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,48 @@
11
using Flux
22

3-
@testset "fourier" begin
3+
@testset "SpectralConv1d" begin
44
modes = 16
55
ch = 64 => 64
66

77
m = Chain(
8-
Conv((1, ), 2=>64),
8+
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
99
SpectralConv1d(ch, modes)
1010
)
1111

1212
𝐱, _ = get_data()
13-
@test size(m(𝐱)) == (1024, 64, 1000)
13+
@test size(m(𝐱)) == (64, 1024, 1000)
14+
15+
T = Float32
16+
loss(x, y) = Flux.mse(real.(m(x)), y)
17+
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
18+
Flux.train!(loss, params(m), data, Flux.ADAM())
19+
end
20+
21+
@testset "FourierOperator" begin
22+
modes = 16
23+
ch = 64 => 64
24+
25+
m = Chain(
26+
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
27+
FourierOperator(ch, modes)
28+
)
29+
30+
𝐱, _ = get_data()
31+
@test size(m(𝐱)) == (64, 1024, 1000)
32+
33+
T = Float32
34+
loss(x, y) = Flux.mse(real.(m(x)), y)
35+
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
36+
Flux.train!(loss, params(m), data, Flux.ADAM())
1437
end
1538

1639
@testset "FNO" begin
1740
𝐱, 𝐲 = get_data()
1841
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
1942
@test size(FNO()(𝐱)) == size(𝐲)
2043

21-
# m = FNO()
22-
# loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
23-
# data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
24-
# Flux.train!(loss, params(m), data, Flux.ADAM())
44+
m = FNO()
45+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
46+
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
47+
Flux.train!(loss, params(m), data, Flux.ADAM())
2548
end

0 commit comments

Comments
 (0)