module FlowOverCircle

using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
-using NeuralOperators, Flux
+using NeuralOperators, Flux, GeometricFlux, Graphs
using CUDA, FluxTraining, BSON

function circle(n, m; Re=250) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)

@@ -31,9 +31,16 @@ function gen_data(ts::AbstractRange)
    return 𝐩s
end

-function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100)
+function get_dataloader(; ts::AbstractRange=LinRange(100, 11000, 10000), ratio::Float64=0.95, batchsize=100, flatten=false)
    data = gen_data(ts)
-    data_train, data_test = splitobs(shuffleobs((𝐱=data[:, :, :, 1:end-1], 𝐲=data[:, :, :, 2:end])), at=ratio)
+    𝐱, 𝐲 = data[:, :, :, 1:end-1], data[:, :, :, 2:end]
+    n = length(ts) - 1
+
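+    # flatten=true collapses each snapshot to a 1×(number of grid points)×n array,
+    # i.e. one scalar feature per graph node, as expected by the graph-based model in train_gno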
+    if flatten
+        𝐱, 𝐲 = reshape(𝐱, 1, :, n), reshape(𝐲, 1, :, n)
+    end
+
+    data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at=ratio)

    loader_train = DataLoader(data_train, batchsize=batchsize, shuffle=true)
    loader_test = DataLoader(data_test, batchsize=batchsize, shuffle=false)

@@ -66,6 +73,40 @@ function train(; epochs=50)
    return learner
end

+function train_gno(; epochs=50)
+    if has_cuda()
+        @info "CUDA is on"
+        device = gpu
+        CUDA.allowscalar(false)
+    else
+        device = cpu
+    end
+
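+    # lattice graph over the flattened simulation grid (presumably 96×64 pressure samples);
+    # each WithGraph layer below operates on this same fixed graph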
+    featured_graph = FeaturedGraph(grid([96, 64]))
+
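+    # graph neural operator: lift the scalar node feature to 16 channels,
+    # apply four graph kernel layers, then project back to one channel per node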
+    model = Chain(
+        Dense(1, 16),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        WithGraph(featured_graph, GraphKernel(Dense(2*16, 16, gelu), 16)),
+        Dense(16, 1),
+    )
+    data = get_dataloader(batchsize=16, flatten=true)
+    optimiser = Flux.Optimiser(WeightDecay(1f-4), Flux.ADAM(1f-3))
+    loss_func = l₂loss
+
+    learner = Learner(
+        model, data, optimiser, loss_func,
+        ToDevice(device, device),
+        Checkpointer(joinpath(@__DIR__, "../model/"))
+    )
+
+    fit!(learner, epochs)
+
+    return learner
+end
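+# note: checkpoints are written to ../model/, the same folder get_model() below reads from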
+
function get_model()
    model_path = joinpath(@__DIR__, "../model/")
    model_file = readdir(model_path)[end]