Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit db8a1b6

Browse files
committed
refactor example
1 parent 07eef68 commit db8a1b6

File tree

1 file changed

+10
-12
lines changed

1 file changed

+10
-12
lines changed

example/burgers.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,23 @@ end
1313
m = FourierNeuralOperator() |> device
1414
loss(𝐱, 𝐲) = sum(abs2, 𝐲 .- m(𝐱)) / size(𝐱)[end]
1515

16+
n_train = 1800
17+
n_test = 200
18+
batchsize = 100
1619
𝐱, 𝐲 = get_burgers_data(n=2048)
1720

18-
n_train = 1000
1921
𝐱_train, 𝐲_train = 𝐱[:, :, 1:n_train], 𝐲[:, 1:n_train]
20-
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=20, shuffle=true)
22+
loader_train = Flux.DataLoader((𝐱_train, 𝐲_train), batchsize=batchsize, shuffle=true)
2123

22-
n_test = 100
2324
𝐱_test, 𝐲_test = 𝐱[:, :, end-n_test+1:end], 𝐲[:, end-n_test+1:end]
24-
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=20, shuffle=false)
25+
loader_test = Flux.DataLoader((𝐱_test, 𝐲_test), batchsize=batchsize, shuffle=false)
2526

26-
function loss_test()
27-
l = 0f0
28-
for (𝐱, 𝐲) in loader_test
29-
𝐱, 𝐲 = device(𝐱), device(𝐲)
30-
l += loss(𝐱, 𝐲)
31-
end
32-
@info "loss: $(l/length(loader_test))"
27+
function validate()
28+
validation_losses = [loss(device(𝐱), device(𝐲)) for (𝐱, 𝐲) in loader_test]
29+
@info "loss: $(sum(validation_losses)/length(loader_test))"
3330
end
3431

3532
data = [(𝐱, 𝐲) for (𝐱, 𝐲) in loader_train] |> device
3633
opt = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
37-
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=Flux.throttle(loss_test, 5)))
34+
call_back = Flux.throttle(validate, 5, leading=false, trailing=true)
35+
Flux.@epochs 500 @time(Flux.train!(loss, params(m), data, opt, cb=call_back))

0 commit comments

Comments
 (0)