Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f121831

Browse files
committed
seal complex type between fft and ifft
1 parent f540576 commit f121831

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
1010
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
1212
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
13+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1314

1415
[compat]
1516
julia = "1.6"

src/fourier.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
using Flux
22
using FFTW
33
using Tullio
4+
using Zygote
45

56
export
67
SpectralConv1d,
78
FourierOperator,
89
FNO
910

10-
c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
11-
1211
struct SpectralConv1d{T, S}
1312
weight::T
1413
in_channel::S
@@ -17,8 +16,12 @@ struct SpectralConv1d{T, S}
1716
σ
1817
end
1918

19+
function c_glorot_uniform(dims...)
20+
return Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...) * im
21+
end
22+
2023
function SpectralConv1d(
21-
ch::Pair{<:Integer,<:Integer},
24+
ch::Pair{<:Integer, <:Integer},
2225
modes::Integer,
2326
σ=identity;
2427
init=c_glorot_uniform,
@@ -28,7 +31,10 @@ function SpectralConv1d(
2831
scale = one(T) / (in_chs * out_chs)
2932
weights = scale * init(out_chs, in_chs, modes)
3033

31-
return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
34+
return Chain(
35+
x -> Zygote.hook(real, x),
36+
SpectralConv1d(weights, in_chs, out_chs, modes, σ)
37+
)
3238
end
3339

3440
Flux.@functor SpectralConv1d
@@ -46,17 +52,17 @@ function (m::SpectralConv1d)(𝐱::AbstractArray)
4652

4753
𝐱_out = ifft(𝐱_padded, 2)
4854

49-
return m.σ.(𝐱_out)
55+
return m.σ.(real(𝐱_out))
5056
end
5157

5258
function FourierOperator(
53-
ch::Pair{<:Integer,<:Integer},
59+
ch::Pair{<:Integer, <:Integer},
5460
modes::Integer,
5561
σ=identity
5662
)
5763
return Chain(
5864
Parallel(+,
59-
Dense(ch.first, ch.second, init=c_glorot_uniform),
65+
Dense(ch.first, ch.second),
6066
SpectralConv1d(ch, modes)
6167
),
6268
x -> σ.(x)
@@ -66,16 +72,16 @@ end
6672
function FNO()
6773
modes = 16
6874
ch = 64 => 64
69-
σ = x -> @. log(1 + exp(x))
75+
σ = relu
7076

7177
return Chain(
72-
Dense(2, 64, init=c_glorot_uniform),
78+
Dense(2, 64),
7379
FourierOperator(ch, modes, σ),
7480
FourierOperator(ch, modes, σ),
7581
FourierOperator(ch, modes, σ),
7682
FourierOperator(ch, modes),
77-
Dense(64, 128, σ, init=c_glorot_uniform),
78-
Dense(128, 1, init=c_glorot_uniform),
83+
Dense(64, 128, σ),
84+
Dense(128, 1),
7985
flatten
8086
)
8187
end

test/fourier.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ using Flux
55
ch = 64 => 64
66

77
m = Chain(
8-
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
8+
Dense(2, 64),
99
SpectralConv1d(ch, modes)
1010
)
1111

1212
𝐱, _ = get_data()
1313
@test size(m(𝐱)) == (64, 1024, 1000)
1414

1515
T = Float32
16-
loss(x, y) = Flux.mse(real.(m(x)), y)
16+
loss(x, y) = Flux.mse(m(x), y)
1717
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
1818
Flux.train!(loss, params(m), data, Flux.ADAM())
1919
end
@@ -23,16 +23,15 @@ end
2323
ch = 64 => 64
2424

2525
m = Chain(
26-
Dense(2, 64, init=NeuralOperators.c_glorot_uniform),
26+
Dense(2, 64),
2727
FourierOperator(ch, modes)
2828
)
2929

3030
𝐱, _ = get_data()
3131
@test size(m(𝐱)) == (64, 1024, 1000)
3232

33-
T = Float32
34-
loss(x, y) = Flux.mse(real.(m(x)), y)
35-
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
33+
loss(x, y) = Flux.mse(m(x), y)
34+
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
3635
Flux.train!(loss, params(m), data, Flux.ADAM())
3736
end
3837

0 commit comments

Comments
 (0)