Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit bbe1b30

Browse files
committed
refactor model
1 parent af04bf9 commit bbe1b30

File tree

6 files changed

+32
-28
lines changed

6 files changed

+32
-28
lines changed

src/NeuralOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ module NeuralOperators
1414

1515
include("data.jl")
1616
include("fourier.jl")
17+
include("model.jl")
1718
end

src/fourier.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,3 @@ function FourierOperator(
6363
x -> σ.(x)
6464
)
6565
end
66-
67-
function FNO()
68-
modes = 16
69-
ch = 64 => 64
70-
σ = relu
71-
72-
return Chain(
73-
Dense(2, 64),
74-
FourierOperator(ch, modes, σ),
75-
FourierOperator(ch, modes, σ),
76-
FourierOperator(ch, modes, σ),
77-
FourierOperator(ch, modes),
78-
Dense(64, 128, σ),
79-
Dense(128, 1),
80-
flatten
81-
)
82-
end

src/model.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
export
2+
FourierNeuralOperator
3+
4+
function FourierNeuralOperator()
5+
modes = 16
6+
ch = 64 => 64
7+
σ = relu
8+
9+
return Chain(
10+
Dense(2, 64),
11+
FourierOperator(ch, modes, σ),
12+
FourierOperator(ch, modes, σ),
13+
FourierOperator(ch, modes, σ),
14+
FourierOperator(ch, modes),
15+
Dense(64, 128, σ),
16+
Dense(128, 1),
17+
flatten
18+
)
19+
end

test/fourier.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,3 @@ end
3232
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
3333
Flux.train!(loss, params(m), data, Flux.ADAM())
3434
end
35-
36-
@testset "FNO" begin
37-
𝐱, 𝐲 = get_burgers_data()
38-
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
39-
@test size(FNO()(𝐱)) == size(𝐲)
40-
41-
m = FNO()
42-
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
43-
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
44-
Flux.train!(loss, params(m), data, Flux.ADAM())
45-
end

test/model.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "FourierNeuralOperator" begin
2+
m = FourierNeuralOperator()
3+
4+
𝐱, 𝐲 = get_burgers_data()
5+
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
6+
@test size(m(𝐱)) == size(𝐲)
7+
8+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
9+
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
10+
Flux.train!(loss, params(m), data, Flux.ADAM())
11+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ using Flux
55
@testset "NeuralOperators.jl" begin
66
include("data.jl")
77
include("fourier.jl")
8+
include("model.jl")
89
end

0 commit comments

Comments
 (0)