Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit c797169

Browse files
authored
Merge pull request #5 from foldfelis/refactor
Refactor
2 parents 4c00dd0 + bbe1b30 commit c797169

File tree

9 files changed

+52
-49
lines changed

9 files changed

+52
-49
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
name: CI
2+
env:
3+
DATADEPS_ALWAYS_ACCEPT: true
24
on:
35
- push
46
- pull_request

src/NeuralOperators.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
11
module NeuralOperators
2+
using DataDeps
3+
using Fetch
4+
using MAT
5+
6+
using Flux
7+
using FFTW
8+
using Tullio
9+
using Zygote
10+
211
function __init__()
312
register_datasets()
413
end
514

6-
include("preprocess.jl")
15+
include("data.jl")
716
include("fourier.jl")
17+
include("model.jl")
818
end

src/preprocess.jl renamed to src/data.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
using DataDeps
2-
using Fetch
3-
using MAT
4-
51
export
6-
get_data
2+
get_burgers_data
73

84
function register_datasets()
95
register(DataDep(
@@ -19,7 +15,7 @@ function register_datasets()
1915
))
2016
end
2117

22-
function get_data(; n=1000, Δsamples=2^3, grid_size=div(2^13, Δsamples))
18+
function get_burgers_data(; n=1000, Δsamples=2^3, grid_size=div(2^13, Δsamples))
2319
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
2420
x_data = collect(read(file, "a")[1:n, 1:Δsamples:end]')
2521
y_data = collect(read(file, "u")[1:n, 1:Δsamples:end]')

src/fourier.jl

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
using Flux
2-
using FFTW
3-
using Tullio
4-
using Zygote
5-
61
export
72
SpectralConv1d,
83
FourierOperator,
@@ -68,20 +63,3 @@ function FourierOperator(
6863
x -> σ.(x)
6964
)
7065
end
71-
72-
function FNO()
73-
modes = 16
74-
ch = 64 => 64
75-
σ = relu
76-
77-
return Chain(
78-
Dense(2, 64),
79-
FourierOperator(ch, modes, σ),
80-
FourierOperator(ch, modes, σ),
81-
FourierOperator(ch, modes, σ),
82-
FourierOperator(ch, modes),
83-
Dense(64, 128, σ),
84-
Dense(128, 1),
85-
flatten
86-
)
87-
end

src/model.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
export
2+
FourierNeuralOperator
3+
4+
function FourierNeuralOperator()
5+
modes = 16
6+
ch = 64 => 64
7+
σ = relu
8+
9+
return Chain(
10+
Dense(2, 64),
11+
FourierOperator(ch, modes, σ),
12+
FourierOperator(ch, modes, σ),
13+
FourierOperator(ch, modes, σ),
14+
FourierOperator(ch, modes),
15+
Dense(64, 128, σ),
16+
Dense(128, 1),
17+
flatten
18+
)
19+
end

test/preprocess.jl renamed to test/data.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
@testset "get data" begin
2-
xs, ys = get_data()
1+
@testset "get burgers data" begin
2+
xs, ys = get_burgers_data()
33

44
@test size(xs) == (2, 1024, 1000)
55
@test size(ys) == (1024, 1000)

test/fourier.jl

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Flux
2-
31
@testset "SpectralConv1d" begin
42
modes = 16
53
ch = 64 => 64
@@ -9,7 +7,7 @@ using Flux
97
SpectralConv1d(ch, modes)
108
)
119

12-
𝐱, _ = get_data()
10+
𝐱, _ = get_burgers_data()
1311
@test size(m(𝐱)) == (64, 1024, 1000)
1412

1513
T = Float32
@@ -27,21 +25,10 @@ end
2725
FourierOperator(ch, modes)
2826
)
2927

30-
𝐱, _ = get_data()
28+
𝐱, _ = get_burgers_data()
3129
@test size(m(𝐱)) == (64, 1024, 1000)
3230

3331
loss(x, y) = Flux.mse(m(x), y)
3432
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
3533
Flux.train!(loss, params(m), data, Flux.ADAM())
3634
end
37-
38-
@testset "FNO" begin
39-
𝐱, 𝐲 = get_data()
40-
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
41-
@test size(FNO()(𝐱)) == size(𝐲)
42-
43-
m = FNO()
44-
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
45-
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
46-
Flux.train!(loss, params(m), data, Flux.ADAM())
47-
end

test/model.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@testset "FourierNeuralOperator" begin
2+
m = FourierNeuralOperator()
3+
4+
𝐱, 𝐲 = get_burgers_data()
5+
𝐱, 𝐲 = Float32.(𝐱), Float32.(𝐲)
6+
@test size(m(𝐱)) == size(𝐲)
7+
8+
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
9+
data = [(𝐱[:, :, 1:5], 𝐲[:, 1:5])]
10+
Flux.train!(loss, params(m), data, Flux.ADAM())
11+
end

test/runtests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
using NeuralOperators
22
using Test
3-
4-
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
3+
using Flux
54

65
@testset "NeuralOperators.jl" begin
7-
include("preprocess.jl")
6+
include("data.jl")
87
include("fourier.jl")
8+
include("model.jl")
99
end

0 commit comments

Comments
 (0)