This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 60b7726

Merge pull request #79 from yuehhua/gno
Fix GNO example
2 parents: 54602e6 + c51ef55

3 files changed: +76 -28 lines changed

example/FlowOverCircle/src/FlowOverCircle.jl

Lines changed: 41 additions & 16 deletions
@@ -3,6 +3,7 @@ module FlowOverCircle
 using WaterLily, LinearAlgebra, ProgressMeter, MLUtils
 using NeuralOperators, Flux, GeometricFlux, Graphs
 using CUDA, FluxTraining, BSON
+using GeometricFlux.GraphSignals: generate_grid
 
 function circle(n, m; Re = 250) # copy from [WaterLily](https://github.com/weymouth/WaterLily.jl)
     # Set physical parameters
@@ -31,16 +32,12 @@ function gen_data(ts::AbstractRange)
     return 𝐩s
 end
 
-function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
-                        ratio::Float64 = 0.95, batchsize = 100, flatten = false)
+function get_mno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
+                            ratio::Float64 = 0.95, batchsize = 100)
     data = gen_data(ts)
     𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
     n = length(ts) - 1
 
-    if flatten
-        𝐱, 𝐲 = reshape(𝐱, 1, :, n), reshape(𝐲, 1, :, n)
-    end
-
     data_train, data_test = splitobs(shuffleobs((𝐱, 𝐲)), at = ratio)
 
     loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
@@ -49,7 +46,7 @@ function get_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
     return loader_train, loader_test
 end
 
-function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
+function train_mno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu
         CUDA.allowscalar(false)
@@ -61,7 +58,7 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
 
     model = MarkovNeuralOperator(ch = (1, 64, 64, 64, 64, 64, 1), modes = (24, 24),
                                  σ = gelu)
-    data = get_dataloader()
+    data = get_mno_dataloader()
    optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
     loss_func = l₂loss
 
@@ -74,6 +71,32 @@ function train(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     return learner
 end
 
+function get_gno_dataloader(; ts::AbstractRange = LinRange(100, 11000, 10000),
+                            ratio::Float64 = 0.95, batchsize = 8)
+    data = gen_data(ts)
+    𝐱, 𝐲 = data[:, :, :, 1:(end - 1)], data[:, :, :, 2:end]
+    n = length(ts) - 1
+
+    # generate graph
+    graph = Graphs.grid(size(data)[2:3])
+
+    # add grid coordinates
+    grid = generate_coordinates(𝐱[1, :, :, 1])
+    grid = repeat(grid, outer = (1, 1, 1, n))
+    𝐱 = vcat(𝐱, grid)
+
+    # flatten
+    𝐱, 𝐲 = reshape(𝐱, size(𝐱, 1), :, n), reshape(𝐲, 1, :, n)
+
+    fg = FeaturedGraph(graph, nf = 𝐱, pf = 𝐱)
+    data_train, data_test = splitobs(shuffleobs((fg, 𝐲)), at = ratio)
+
+    loader_train = DataLoader(data_train, batchsize = batchsize, shuffle = true)
+    loader_test = DataLoader(data_test, batchsize = batchsize, shuffle = false)
+
+    return loader_train, loader_test
+end
+
 function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
     if cuda && CUDA.has_cuda()
         device = gpu
@@ -84,17 +107,19 @@ function train_gno(; cuda = true, η₀ = 1.0f-3, λ = 1.0f-4, epochs = 50)
         @info "Training on CPU"
     end
 
-    featured_graph = FeaturedGraph(grid([96, 64]))
-    model = Chain(Dense(1, 16),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
-                  WithGraph(featured_graph, GraphKernel(Dense(2 * 16, 16, gelu), 16)),
+    grid_dim = 2
+    edge_dim = 2(grid_dim + 1)
+    model = Chain(GraphParallel(node_layer = Dense(grid_dim + 1, 16)),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  GraphKernel(Dense(edge_dim, abs2(16), gelu), 16),
+                  node_feature,
                   Dense(16, 1))
-    data = get_dataloader(batchsize = 16, flatten = true)
+
     optimiser = Flux.Optimiser(WeightDecay(λ), Flux.Adam(η₀))
     loss_func = l₂loss
-
+    data = get_gno_dataloader()
     learner = Learner(model, data, optimiser, loss_func,
                       ToDevice(device, device),
                       Checkpointer(joinpath(@__DIR__, "../model/")))
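
With this change the example exposes one entry point per model instead of a shared train with a flatten flag. A minimal usage sketch (assuming the example's environment is instantiated; the module, function names, and keyword defaults are taken from the diff above):

using FlowOverCircle  # the example module defined above

# Markov neural operator: builds its data via get_mno_dataloader()
learner_mno = FlowOverCircle.train_mno(; cuda = true, epochs = 50)

# Graph neural operator: builds its FeaturedGraph data via get_gno_dataloader()
learner_gno = FlowOverCircle.train_gno(; cuda = true, epochs = 50)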

src/graph_kernel.jl

Lines changed: 20 additions & 2 deletions
@@ -23,6 +23,7 @@ end
 
 function GraphKernel(κ, ch::Int, σ = identity; init = Flux.glorot_uniform)
     W = init(ch, ch)
+
     return GraphKernel(W, κ, σ)
 end
 
@@ -33,6 +34,7 @@ function GeometricFlux.message(l::GraphKernel, x_i, x_j::AbstractArray, e_ij::Ab
     K = l.κ(e_ij)
     dims = size(K)[2:end]
     m_ij = GeometricFlux._matmul(reshape(K, N, N, :), reshape(x_j, N, 1, :))
+
     return reshape(m_ij, N, dims...)
 end
 
@@ -42,13 +44,29 @@ end
 
 function (l::GraphKernel)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
     GraphSignals.check_num_nodes(el.N, X)
-    GraphSignals.check_num_nodes(el.E, E)
+    GraphSignals.check_num_edges(el.E, E)
     _, V, _ = GeometricFlux.propagate(l, el, E, X, nothing, mean, nothing, nothing)
+
     return V
 end
 
+# For variable graph
+function (l::GraphKernel)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
+    pf = positional_feature(fg)
+    GraphSignals.check_num_nodes(fg, nf)
+    GraphSignals.check_num_nodes(fg, pf)
+    el = GeometricFlux.GraphSignals.to_namedtuple(fg)
+
+    # node features + positional features as edge features
+    ef = vcat(GeometricFlux._gather(pf, el.xs), GeometricFlux._gather(pf, el.nbrs))
+    _, V, _ = GeometricFlux.propagate(l, el, ef, nf, nothing, mean, nothing, nothing)
+
+    return ConcreteFeaturedGraph(fg, nf = V)
+end
+
 function Base.show(io::IO, l::GraphKernel)
-    channel, _ = size(l.linear)
+    channel = size(l.linear, 1)
     print(io, "GraphKernel(", l.κ, ", channel=", channel)
     l.σ == identity || print(io, ", ", l.σ)
     print(io, ")")

test/graph_kernel.jl

Lines changed: 15 additions & 10 deletions
@@ -1,19 +1,24 @@
 @testset "GraphKernel" begin
     batch_size = 5
-    channel = 32
+    channel = 1
     coord_dim = 2
     N = 10
 
     graph = grid([N, N])
-    κ = Dense(2(coord_dim + 1), abs2(channel), relu)
-
     𝐱 = rand(Float32, channel, nv(graph), batch_size)
-    E = rand(Float32, 2(coord_dim + 1), ne(graph), batch_size)
-    l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
-    @test repr(l.layer) ==
-        "GraphKernel(Dense($(2(coord_dim + 1)) => $(abs2(channel)), relu), channel=32)"
-    @test size(l(𝐱, E)) == (channel, nv(graph), batch_size)
+    κ = Dense(2(coord_dim + channel), abs2(channel), relu)
+    κ_in_dim, κ_out_dim = 2(coord_dim + channel), abs2(channel)
+
+    @testset "layer without graph" begin
+        pf = rand(Float32, coord_dim, nv(graph), batch_size)
+        pf = vcat(𝐱, pf)
+        l = GraphKernel(κ, channel)
+        fg = FeaturedGraph(graph, nf = 𝐱, pf = pf)
+        fg_ = l(fg)
+        @test size(node_feature(fg_)) == (channel, nv(graph), batch_size)
+        @test_throws MethodError l(𝐱)
 
-    g = Zygote.gradient(() -> sum(l(𝐱, E)), Flux.params(l))
-    @test length(g.grads) == 3
+        g = gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
+        @test length(g.grads) == 5
+    end
 end
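
End to end, the calling convention exercised by the updated test reads as follows outside a testset (a sketch assuming this repository's NeuralOperators package for GraphKernel, plus Graphs, GraphSignals, and Flux as in the test file):

using Flux, Graphs, GraphSignals, NeuralOperators

channel, coord_dim, batch_size = 1, 2, 5
graph = grid([10, 10])

𝐱 = rand(Float32, channel, nv(graph), batch_size)              # node features
pf = vcat(𝐱, rand(Float32, coord_dim, nv(graph), batch_size))  # positional features

κ = Dense(2(coord_dim + channel), abs2(channel), relu)
l = GraphKernel(κ, channel)

fg = FeaturedGraph(graph, nf = 𝐱, pf = pf)
fg_ = l(fg)  # a FeaturedGraph carrying the updated node features
@assert size(node_feature(fg_)) == (channel, nv(graph), batch_size)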
