Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit ce679e4

Browse files
authored
Merge pull request #57 from SciML/training_process
Revise training process
2 parents f5a9017 + 3768aee commit ce679e4

31 files changed

+400
-548
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ docs/site/
2525
# environment.
2626
Manifest.toml
2727

28-
*.jld2
28+
*.bson

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1616
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[compat]
19-
CUDA = "3.8"
19+
CUDA = "3.9"
2020
CUDAKernels = "0.3, 0.4"
21-
ChainRulesCore = "1.13"
21+
ChainRulesCore = "1.14"
2222
FFTW = "1.4"
2323
Flux = "0.13"
2424
GeometricFlux = "0.11"

example/Burgers/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ name = "Burgers"
22
uuid = "5b053d85-f964-4905-ae31-99551cd8d3ad"
33

44
[deps]
5+
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
56
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
67
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
78
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
9+
FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
810
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
11+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
912
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
1013

1114
[extras]

example/Burgers/src/Burgers.jl

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,72 @@
11
module Burgers
22

3-
using NeuralOperators
4-
using Flux
5-
using CUDA
3+
using DataDeps, MAT, MLUtils
4+
using NeuralOperators, Flux
5+
using CUDA, FluxTraining, BSON
66

7-
include("data.jl")
87
include("Burgers_deeponet.jl")
98

9+
"""
    register_burgers()

Register the `"Burgers"` data dependency with DataDeps so that `datadep"Burgers"`
resolves to the downloaded dataset. The archive is unpacked after it is fetched.
"""
function register_burgers()
    description = """
        Burgers' equation dataset from
        [fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
        """
    remote_url = "http://www.med.cgu.edu.tw/NeuralOperators/Burgers_R10.zip"
    # Checksum handed to DataDeps to validate the downloaded archive.
    checksum = "9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd"

    burgers_dep = DataDep("Burgers", description, remote_url, checksum;
                          post_fetch_method=unpack)

    return register(burgers_dep)
end
21+
22+
"""
    get_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)

Load the first `n` samples of the `"a"` and `"u"` variables from
`burgers_data_R10.mat`, keeping every `Δsamples`-th grid point, and convert
them to element type `T`.

# Keywords
- `n`: number of samples (rows of the stored matrices) to load.
- `Δsamples`: stride used to subsample the stored grid.
- `grid_size`: number of grid points left after subsampling. It must equal the
  number of columns selected by `1:Δsamples:end`.
- `T`: element type of the returned arrays.

# Returns
A tuple `(x, y)` where
- `x` is a `2 × grid_size × n` array: channel 1 holds a uniform grid on `[0, 1]`,
  channel 2 holds the subsampled `"a"` values;
- `y` is a `1 × grid_size × n` array holding the subsampled `"u"` values.

# Throws
- `DimensionMismatch` if `grid_size` does not match the subsampled data.
"""
function get_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
    path = joinpath(datadep"Burgers", "burgers_data_R10.mat")

    # The do-block form closes the file even if a read or the slicing throws.
    x_data, y_data = matopen(path) do file
        a = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
        u = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
        a, u
    end

    # Fail early with a readable message instead of an opaque broadcast error.
    size(x_data, 1) == grid_size || throw(DimensionMismatch(
        "grid_size = $grid_size does not match the $(size(x_data, 1)) grid points " *
        "selected with Δsamples = $Δsamples"
    ))

    x_loc_data = Array{T, 3}(undef, 2, grid_size, n)
    # Broadcasting the grid along the sample dimension avoids a repeated temporary.
    x_loc_data[1, :, :] .= LinRange(0, 1, grid_size)
    x_loc_data[2, :, :] .= x_data

    return x_loc_data, reshape(y_data, 1, :, n)
end
34+
35+
"""
    get_dataloader(; ratio::Float64=0.9, batchsize=100)

Load 2048 samples with `get_data`, split them at `ratio` into a training and a
test portion, and wrap each portion in a `DataLoader` with the given `batchsize`.
Only the training loader shuffles its observations.

Returns `(loader_train, loader_test)`.
"""
function get_dataloader(; ratio::Float64=0.9, batchsize=100)
    inputs, targets = get_data(n=2048)
    train_part, test_part = splitobs((inputs, targets), at=ratio)

    return (
        DataLoader(train_part, batchsize=batchsize, shuffle=true),
        DataLoader(test_part, batchsize=batchsize, shuffle=false),
    )
end
44+
1045
__init__() = register_burgers()
1146

12-
function train()
13-
if has_cuda()
14-
@info "CUDA is on"
47+
function train(; cuda=true, η₀=1f-3, λ=1f-4, epochs=500)
48+
if cuda && CUDA.has_cuda()
1549
device = gpu
1650
CUDA.allowscalar(false)
51+
@info "Training on GPU"
1752
else
1853
device = cpu
54+
@info "Training on CPU"
1955
end
2056

21-
modes = (16, )
22-
ch = 64 => 64
23-
σ = gelu
24-
Transform = FourierTransform
25-
m = Chain(
26-
Dense(2, 64),
27-
OperatorKernel(ch, modes, Transform, σ),
28-
OperatorKernel(ch, modes, Transform, σ),
29-
OperatorKernel(ch, modes, Transform, σ),
30-
OperatorKernel(ch, modes, Transform),
31-
Dense(64, 128, σ),
32-
Dense(128, 1),
33-
flatten
34-
) |> device
35-
36-
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
37-
38-
loader_train, loader_test = get_dataloader()
39-
40-
function validate()
41-
validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]
42-
@info "loss: $(sum(validation_losses)/length(loader_test))"
43-
end
57+
model = FourierNeuralOperator(ch=(2, 64, 64, 64, 64, 64, 128, 1), modes=(16, ), σ=gelu)
58+
data = get_dataloader()
59+
optimiser = Flux.Optimiser(WeightDecay(λ), Flux.ADAM(η₀))
60+
loss_func = l₂loss
61+
62+
learner = Learner(
63+
model, data, optimiser, loss_func,
64+
ToDevice(device, device),
65+
)
66+
67+
fit!(learner, epochs)
4468

45-
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
46-
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
47-
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
48-
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))
69+
return learner
4970
end
5071

5172
end

example/Burgers/src/Burgers_deeponet.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
"""
    get_data_don(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples))

Load the first `n` rows of the `"a"` and `"u"` variables from
`burgers_data_R10.mat`, keeping every `Δsamples`-th column, and return them as
the tuple `(x_data, y_data)`. The element type stored in the file is kept and
the arrays are not transposed.

`grid_size` is accepted for interface compatibility but is not used here.
"""
function get_data_don(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples))
    path = joinpath(datadep"Burgers", "burgers_data_R10.mat")

    # The do-block form closes the file even if a read or the slicing throws.
    return matopen(path) do file
        x_data = collect(read(file, "a")[1:n, 1:Δsamples:end])
        y_data = collect(read(file, "u")[1:n, 1:Δsamples:end])
        x_data, y_data
    end
end
9+
110
function train_don(; n=300, cuda=true, learning_rate=0.001, epochs=400)
211
if cuda && has_cuda()
312
@info "Training on GPU"
@@ -20,7 +29,7 @@ function train_don(; n=300, cuda=true, learning_rate=0.001, epochs=400)
2029
opt = ADAM(learning_rate)
2130

2231
m = DeepONet((1024,1024,1024), (1,1024,1024), gelu, gelu) |> device
23-
32+
2433
loss(X, y, sensor) = Flux.Losses.mse(m(X, sensor), y)
2534
evalcb() = @show(loss(xval, yval, grid))
2635

example/Burgers/src/data.jl

Lines changed: 0 additions & 49 deletions
This file was deleted.

example/Burgers/test/data.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

example/Burgers/test/runtests.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@ using Burgers
22
using Test
33

44
@testset "Burgers" begin
5-
include("data.jl")
6-
include("deeponet.jl")
5+
xs, ys = Burgers.get_data(n=1000)
6+
7+
@test size(xs) == (2, 1024, 1000)
8+
@test size(ys) == (1, 1024, 1000)
9+
10+
learner = Burgers.train(epochs=10)
11+
loss = learner.cbstate.metricsepoch[ValidationPhase()][:Loss].values[end]
12+
@test loss < 0.1
13+
14+
# include("deeponet.jl")
715
end

example/DoublePendulum/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ name = "DoublePendulum"
22
uuid = "0c23c1c1-5f41-4617-a685-ac46aae913c3"
33

44
[deps]
5+
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
56
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
67
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
78
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
89
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
910
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
10-
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
11+
FluxTraining = "7bf95e4d-ca32-48da-9824-f0dc5310474f"
12+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1113
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
1214
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
1315
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"

example/DoublePendulum/notebook/double_pendulum.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
### A Pluto.jl notebook ###
2-
# v0.16.0
2+
# v0.19.0
33

44
using Markdown
55
using InteractiveUtils
@@ -28,10 +28,7 @@ The data is provided by [IBM](https://developer.ibm.com/exchanges/data/all/doubl
2828
"
2929

3030
# ╔═╡ 5268feee-bda2-4612-9d4c-a1db424a11c7
31-
data_x, data_y, _, _ = DoublePendulum.preprocess(
32-
DoublePendulum.get_data(i=20),
33-
ratio=1
34-
);
31+
data_x, data_y = DoublePendulum.preprocess(DoublePendulum.get_data(i=20));
3532

3633
# ╔═╡ 4d0b08a4-8a54-41fd-997f-ad54d4c984cd
3734
m = DoublePendulum.get_model();

0 commit comments

Comments
 (0)