This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 6266a94

separate get_dataloader

1 parent 41d5f16, commit 6266a94

3 files changed: +87 -68 lines

example/FlowOverCircle/src/FlowOverCircle.jl (68 additions, 27 deletions)

```diff
@@ -3,7 +3,7 @@ module FlowOverCircle
 using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
 using NeuralOperators, Flux, GeometricFlux, Graphs
 using CUDA, FluxTraining, BSON
-using GeometricFlux.GraphSignals: generate_coordinates
+using GeometricFlux.GraphSignals: generate_grid
 
 function circle(n, m; Re = 250) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
     # Set physical parameters
@@ -32,29 +32,21 @@ function gen_data(ts::AbstractRange)
     return 𝐩s
 end
 
-function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
-                        ratio::Float64 = 0.95, batchsize = 100, flatten = false)
+function get_mno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
+                            ratio::Float64 = 0.95, batchsize = 100)
     data = gen_data(ts)
     𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
     n = length(ts) - 1
-    grid = generate_coordinates(𝐱[1, :, :, 1])
-    grid = repeat(grid, outer = (1, 1, 1, n))
-    x_with_grid = vcat(𝐱, grid)
-
-    if flatten
-        x_with_grid = reshape(x_with_grid, size(x_with_grid, 1), :, n)
-        𝐲 = reshape(𝐲, 1, :, n)
-    end
 
-    data_train, data_test = splitobs(shuffleobs((x_with_grid, 𝐲)), at = ratio)
+    data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)
 
     loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
     loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)
 
     return loader_train, loader_test
 end
 
-function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
+function train_mno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
@@ -66,7 +58,7 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
 
     model = MarkovNeuralOperator(ch = (1, 64, 64, 64, 64, 64, 1), modes = (24, 24),
                                  σ = gelu)
-    data = get_dataloader()
+    data = get_mno_dataloader()
     optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
     loss_func = l₂loss
 
@@ -79,6 +71,61 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     return learner
 end
 
+function batch_featured_graph(data, graph, batchsize)
+    tot_len = size(data)[end]
+    bch_data = FeaturedGraph[]
+    for i in 1:batchsize:tot_len
+        bch_rng = (i + batchsize >= tot_len) ? (i:tot_len) : (i:(i + batchsize - 1))
+        fg = FeaturedGraph(graph, nf = data[:, :, bch_rng], pf = data[:, :, bch_rng])
+        push!(bch_data, fg)
+    end
+
+    return bch_data
+end
+
+function batch_data(data, batchsize)
+    tot_len = size(data)[end]
+    bch_data = Array{Float32, 3}[]
+    for i in 1:batchsize:tot_len
+        bch_rng = (i + batchsize >= tot_len) ? (i:tot_len) : (i:(i + batchsize - 1))
+        push!(bch_data, data[:, :, bch_rng])
+    end
+
+    return bch_data
+end
+
+function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
+                            ratio::Float64 = 0.95, batchsize = 8)
+    data = gen_data(ts)
+    𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
+    n = length(ts) - 1
+
+    # generate graph
+    graph = Graphs.grid(size(data)[2:3])
+
+    # add grid coordinates
+    grid = generate_coordinates(𝐱[1, :, :, 1])
+    grid = repeat(grid, outer = (1, 1, 1, n))
+    𝐱 = vcat(𝐱, grid)
+
+    # flatten
+    𝐱, 𝐲 = reshape(𝐱, size(𝐱, 1), :, n), reshape(𝐲, 1, :, n)
+
+    data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)
+
+    batched_train_X = batch_featured_graph(data_train[1], graph, batchsize)
+    batched_test_X = batch_featured_graph(data_test[1], graph, batchsize)
+    batched_train_y = batch_data(data_train[2], batchsize)
+    batched_test_y = batch_data(data_test[2], batchsize)
+
+    loader_train = DataLoader((batched_train_X, batched_train_y), batchsize = -1,
+                              shuffle = true)
+    loader_test = DataLoader((batched_test_X, batched_test_y), batchsize = -1,
+                             shuffle = false)
+
+    return loader_train, loader_test
+end
+
 function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu
@@ -91,23 +138,17 @@ function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
 
     grid_dim = 2
     edge_dim = 2(grid_dim + 1)
-    featured_graph = FeaturedGraph(grid([96, 64]))
-    model = Chain(Flux.SkipConnection(Dense(grid_dim + 1, 16), vcat),
-                  # size(x) = (19, 6144, 8)
-                  WithGraph(featured_graph,
-                            GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
-                  WithGraph(featured_graph,
-                            GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
-                  WithGraph(featured_graph,
-                            GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
-                  WithGraph(featured_graph,
-                            GraphKernel(Dense(edge_dim, abs2(16), gelu), 16)),
-                  x -> x[1:end-3, :, :],
+    model = Chain(GraphParallel(node_layer = Dense(grid_dim + 1, 16)),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  node_feature,
                   Dense(16, 1))
 
     optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
     loss_func = l₂loss
-    data = get_dataloader(batchsize = 8, flatten = true)
+    data = get_gno_dataloader()
     learner = Learner(model, data, optimiser, loss_func,
                       ToDevice(device, device),
                       Checkpointer(joinpath(@__DIR__, "../model/")))
```
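The split gives each training entry point its own pipeline: `train_mno` consumes dense arrays, `train_gno` consumes pre-batched `FeaturedGraph`s. A minimal usage sketch, assuming the `FlowOverCircle` module above is loaded; the batch sizes and the toy array below are illustrative, not part of the commit:

```julia
using FlowOverCircle: get_mno_dataloader, get_gno_dataloader, batch_data

# Dense (channel, x, y, batch) samples for the Markov neural operator.
mno_train, mno_test = get_mno_dataloader(batchsize = 100)

# Pre-batched FeaturedGraphs with matching label arrays for the graph operator.
gno_train, gno_test = get_gno_dataloader(batchsize = 8)

# batch_data groups observations along the last axis; the tail batch holds
# whatever remains, e.g. 10 observations with batchsize 4:
toy = rand(Float32, 3, 100, 10)
size.(batch_data(toy, 4))  # [(3, 100, 4), (3, 100, 4), (3, 100, 2)]
```

Because the GNO data is batched up front, its `DataLoader` is built with `batchsize = -1`, which in MLUtils disables the loader's own batching, so each iteration yields one pre-built `(FeaturedGraph, y)` pair.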

src/graph_kernel.jl (11 additions, 9 deletions)

```diff
@@ -44,23 +44,25 @@ end
 
 function (l::GraphKernel)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
     GraphSignals.check_num_nodes(el.N, X)
-    # GraphSignals.check_num_edges(el.E, E)
+    GraphSignals.check_num_edges(el.E, E)
     _, V, _ = GeometricFlux.propagate(l, el, E, X, nothing, mean, nothing, nothing)
 
     return V
 end
 
-# (wg::WithGraph{<:GraphKernel})(Vt::AbstractArray, E::AbstractArray) = wg(nothing, Vt, E), E
-
-function (wg::WithGraph{<:GraphKernel})(input::AbstractArray)
-    el = wg.graph
-    Vt, X_with_grid = input[1:end-3, :, :], input[end-2:end, :, :]
+# For variable graph
+function (l::GraphKernel)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
+    pf = positional_feature(fg)
+    GraphSignals.check_num_nodes(fg, nf)
+    GraphSignals.check_num_nodes(fg, pf)
+    el = GeometricFlux.GraphSignals.to_namedtuple(fg)
 
     # node features + positional features as edge features
-    E = vcat(GeometricFlux._gather(X_with_grid, el.xs),
-             GeometricFlux._gather(X_with_grid, el.nbrs))
+    ef = vcat(GeometricFlux._gather(pf, el.xs), GeometricFlux._gather(pf, el.nbrs))
+    _, V, _ = GeometricFlux.propagate(l, el, ef, nf, nothing, mean, nothing, nothing)
 
-    return vcat(wg.layer(el, Vt, E), X_with_grid)
+    return ConcreteFeaturedGraph(fg, nf = V)
 end
 
 function Base.show(io::IO, l::GraphKernel)
```
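With this method the kernel consumes a `FeaturedGraph` carrying node (`nf`) and positional (`pf`) features and returns a graph with updated node features. A minimal sketch mirroring the updated test below; the dimensions are illustrative, and `GraphKernel` is assumed to be in scope from this package:

```julia
using Flux, Graphs, GraphSignals

channel, coord_dim, batch_size = 16, 2, 4
graph = Graphs.grid([4, 4])

𝐱 = rand(Float32, channel, nv(graph), batch_size)              # node features
pf = vcat(𝐱, rand(Float32, coord_dim, nv(graph), batch_size))  # positional features
fg = FeaturedGraph(graph, nf = 𝐱, pf = pf)

κ = Dense(2(coord_dim + channel), abs2(channel), relu)  # edge kernel network
l = GraphKernel(κ, channel)

fg_ = l(fg)                  # FeaturedGraph with updated node features
size(node_feature(fg_))      # (channel, nv(graph), batch_size)
```

Passing a plain array into this path now raises a `MethodError`, which the updated test asserts explicitly.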

test/graph_kernel.jl (8 additions, 32 deletions)

```diff
@@ -9,40 +9,16 @@
     κ = Dense(2(coord_dim + channel), abs2(channel), relu)
     κ_in_dim, κ_out_dim = 2(coord_dim + channel), abs2(channel)
 
-    @testset "pass edge features" begin
-        E = rand(Float32, 2(coord_dim + channel), ne(graph), batch_size)
-        l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
-        @test repr(l.layer) ==
-              "GraphKernel(Dense($κ_in_dim => $κ_out_dim, relu), channel=$channel)"
-        @test size(l(𝐱, E)) == (channel, nv(graph), batch_size)
-
-        g = Zygote.gradient(() -> sum(l(𝐱, E)), Flux.params(l))
-        @test length(g.grads) == 3
-    end
-
-    @testset "pass positional features" begin
-        pf = rand(Float32, coord_dim, nv(graph), batch_size)
-        pf = vcat(𝐱, pf)
-        fg = FeaturedGraph(graph)
-        l = WithGraph(fg, GraphKernel(κ, channel))
-        @test repr(l.layer) ==
-              "GraphKernel(Dense($κ_in_dim => $κ_out_dim, relu), channel=$channel)"
-        @test size(l(pf, 𝐱, nothing)) == (channel, nv(graph), batch_size)
-
-        g = Zygote.gradient(() -> sum(l(pf, 𝐱, nothing)), Flux.params(l))
-        @test length(g.grads) == 4
-    end
-
-    @testset "pass positional features by FeaturedGraph" begin
+    @testset "layer without graph" begin
         pf = rand(Float32, coord_dim, nv(graph), batch_size)
         pf = vcat(𝐱, pf)
-        fg = FeaturedGraph(graph, pf = pf)
-        l = WithGraph(fg, GraphKernel(κ, channel))
-        @test repr(l.layer) ==
-              "GraphKernel(Dense($κ_in_dim => $κ_out_dim, relu), channel=$channel)"
-        @test size(l(𝐱)) == (channel, nv(graph), batch_size)
+        l = GraphKernel(κ, channel)
+        fg = FeaturedGraph(graph, nf = 𝐱, pf = pf)
+        fg_ = l(fg)
+        @test size(node_feature(fg_)) == (channel, nv(graph), batch_size)
+        @test_throws MethodError l(𝐱)
 
-        g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l))
-        @test length(g.grads) == 3
+        g = gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
+        @test length(g.grads) == 5
     end
 end
```
