Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit f367be8

Browse files
committed
implement SpectralConv1d
1 parent 9ab163c commit f367be8

File tree

7 files changed

+79
-82
lines changed

7 files changed

+79
-82
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
99
Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
1010
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1111
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
12+
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1213

1314
[compat]
1415
julia = "1.6"

src/NeuralOperators.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
module NeuralOperators
2-
using CUDA: reshape
3-
include("fourier_1d.jl")
2+
function __init__()
3+
register_datasets()
4+
end
5+
6+
include("preprocess.jl")
7+
include("fourier.jl")
48
end

src/fourier.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
using Flux
2+
using FFTW
3+
using Tullio
4+
5+
export
6+
SpectralConv1d
7+
8+
struct SpectralConv1d{T,S}
9+
weight::T
10+
in_channel::S
11+
out_channel::S
12+
modes::S
13+
end
14+
15+
function SpectralConv1d(
16+
ch::Pair{<:Integer,<:Integer},
17+
modes::Integer;
18+
init=Flux.glorot_uniform,
19+
T::DataType=Float32
20+
)
21+
in_chs, out_chs = ch
22+
scale = one(T) / (in_chs * out_chs)
23+
weights = scale * init(out_chs, in_chs, modes)
24+
25+
return SpectralConv1d(weights, in_chs, out_chs, modes)
26+
end
27+
28+
Flux.@functor SpectralConv1d
29+
30+
function (m::SpectralConv1d)(𝐱::AbstractArray)
31+
𝐱_fft = rfft(𝐱, 1) # [x, in_chs, batch]
32+
𝐱_selected = 𝐱_fft[1:m.modes, :, :] # [modes, in_chs, batch]
33+
34+
# [modes, out_chs, batch] <- [modes, in_chs, batch] [out_chs, in_chs, modes]
35+
@tullio 𝐱_weighted[m, o, b] := 𝐱_selected[m, i, b] * m.weight[o, i, m]
36+
37+
d = size(𝐱, 1) ÷ 2 + 1 - m.modes
38+
𝐱_padded = cat(𝐱_weighted, zeros(Float32, d, size(𝐱)[2:end]...), dims=1)
39+
40+
𝐱_out = irfft(𝐱_padded , size(𝐱, 1), 1)
41+
42+
return 𝐱_out
43+
end
44+
45+
# function FNO(modes::Integer, width::Integer)
46+
# return Chain(
47+
# PermutedDimsArray(Dense(2, width),(2,1,3)),
48+
# relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
49+
# relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
50+
# relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)),
51+
# PermutedDimsArray(relu(SpectralConv1d(width, width, modes) + Conv(width, width, 1)), (0, 2, 1)),
52+
# Dense(width, 128, relu),
53+
# Dense(128, 1)
54+
# )
55+
# end
56+
57+
# loss(m::SpectralConv1d, x, x̂) = sum(abs2, x̂ .- m(x)) / len

src/fourier_1d.jl

Lines changed: 0 additions & 76 deletions
This file was deleted.

test/fourier.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Flux
2+
3+
@testset "fourier" begin
4+
modes = 16
5+
width = 64
6+
ch = width => width
7+
m = Chain(
8+
Conv((1, ), 2=>width),
9+
SpectralConv1d(ch, modes)
10+
)
11+
12+
𝐱, _ = get_data()
13+
@show size(m(𝐱))
14+
end

test/fourier_1d.jl

Lines changed: 0 additions & 3 deletions
This file was deleted.

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
55

66
@testset "NeuralOperators.jl" begin
77
include("preprocess.jl")
8-
# include("fourier.jl")
8+
include("fourier.jl")
99
end

0 commit comments

Comments
 (0)