Commit 00aade3

Merge pull request #67 from CarloLucibello/cl/dev
NeuralODE example working on cpu and gpu
2 parents: da7b1e4 + dfcb468

6 files changed: +53 -30 lines changed

examples/neural_ode.jl

Lines changed: 12 additions & 9 deletions
@@ -1,11 +1,14 @@
 # Load the packages
-using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
-using Flux: onehotbatch, onecold, throttle
+using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations
+using Flux: onehotbatch, onecold
 using Flux.Losses: logitcrossentropy
 using Statistics: mean
 using MLDatasets: Cora
+using CUDA
+# CUDA.allowscalar(false) # Some scalar indexing is still done by DiffEqFlux
 
-device = cpu # `gpu` not working yet
+# device = cpu # `gpu` not working yet
+device = CUDA.functional() ? gpu : cpu
 
 # LOAD DATA
 data = Cora.dataset()
@@ -39,21 +42,21 @@ node = NeuralODE(WithGraph(node_chain, g),
 model = GNNChain(GCNConv(nin => nhidden, relu),
                  Dropout(0.5),
                  node,
-                 diffeqarray_to_array,
+                 diffeqsol_to_array,
                  Dense(nhidden, nout)) |> device
 
 # Loss
 loss(x, y) = logitcrossentropy(model(g, x), y)
 accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))
 
-# Training
-## Model Parameters
-ps = Flux.params(model, node.p);
+# # Training
+# ## Model Parameters
+ps = Flux.params(model);
 
-## Optimizer
+# ## Optimizer
 opt = ADAM(0.01)
 
-## Training Loop
+# ## Training Loop
 for epoch in 1:epochs
     gs = gradient(() -> loss(X, y), ps)
     Flux.Optimise.update!(opt, ps, gs)
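
The portability fix above hinges on one idiom: select the device once, falling back to the CPU when no usable GPU is present. A minimal sketch of that pattern in isolation (the array sizes below are illustrative, not from the example):

```julia
using Flux, CUDA

# `CUDA.functional()` is false when no working GPU/driver is available,
# so the script degrades gracefully instead of erroring at load time.
device = CUDA.functional() ? gpu : cpu

x = rand(Float32, 7, 4)  # 7 features for 4 nodes (made-up sizes)
x = x |> device          # identity on CPU, copies to a CuArray on GPU
```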

src/GNNGraphs/generate.jl

Lines changed: 5 additions & 3 deletions
@@ -1,5 +1,5 @@
 """
-    rand_graph(n, m; bidirected=true, kws...)
+    rand_graph(n, m; bidirected=true, seed=-1, kws...)
 
 Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes
 and `m` edges.
@@ -8,6 +8,8 @@ If `bidirected=true` the reverse edge of each edge will be present.
 If `bidirected=false` instead, `m` unrelated edges are generated.
 In any case, the output graph will contain no self-loops or multi-edges.
 
+Use a `seed > 0` for reproducibility.
+
 Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
 
 # Usage
@@ -43,10 +45,10 @@ julia> edge_index(g)
 
 ```
 """
-function rand_graph(n::Integer, m::Integer; bidirected=true, kws...)
+function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
     if bidirected
         @assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
     end
     m2 = bidirected ? m÷2 : m
-    return GNNGraph(Graphs.erdos_renyi(n, m2, is_directed=!bidirected); kws...)
+    return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
 end
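
As a hypothetical REPL check of the new keyword (graph sizes are arbitrary): calling `rand_graph` twice with the same positive `seed` should return identical edge sets, while the default `seed=-1` keeps the previous nondeterministic behavior.

```julia
using GraphNeuralNetworks

# Same positive seed ⇒ same Erdős-Rényi sample, per the docstring above.
g1 = rand_graph(10, 30, bidirected=false, seed=17)
g2 = rand_graph(10, 30, bidirected=false, seed=17)
@assert edge_index(g1) == edge_index(g2)
```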

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
 https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
 =============================================#
 
-const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V}
-const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
+const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
+const ADJLIST_T = AbstractVector{T} where T <: AbstractVector{<:Integer}
 const ADJMAT_T = AbstractMatrix
 const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
 const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
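
The tightened aliases mean the element type of COO indices is now checked: a source/target pair of float vectors used to match `COO_T` but no longer does. A small sketch (variable names are illustrative):

```julia
s = [1, 2, 3]   # source node indices: AbstractVector{<:Integer}
t = [2, 3, 1]   # target node indices
w = nothing     # optional edge weights, unconstrained (the V parameter)

coo_ok  = (s, t, w)                             # matches the new COO_T
coo_bad = ([1.0, 2.0, 3.0], [2.0, 3.0, 1.0], w) # matched before, rejected now
```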

src/layers/basic.jl

Lines changed: 28 additions & 14 deletions
@@ -13,7 +13,7 @@ abstract type GNNLayer end
 
 
 """
-WithGraph(model, g::GNNGraph; traingraph=false)
+    WithGraph(model, g::GNNGraph; traingraph=false)
 
 A type wrapping the `model` and tying it to the graph `g`.
 In the forward pass, can only take feature arrays as inputs,
@@ -38,17 +38,31 @@ x2 = rand(Float32, 2, 4)
 @assert wg(g2, x2) == model(g2, x2)
 ```
 """
-struct WithGraph{M}
-    model::M
-    g::GNNGraph
-    traingraph::Bool
+struct WithGraph{M, G<:GNNGraph}
+    model::M
+    g::G
+    traingraph::Bool
 end
 
 WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph)
 
 @functor WithGraph
 Flux.trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
 
+# Work around
+# https://github.com/FluxML/Flux.jl/issues/1733
+# Revisit after
+# https://github.com/FluxML/Flux.jl/pull/1742
+function Flux.destructure(m::WithGraph)
+    @assert m.traingraph == false # TODO
+    p, re = Flux.destructure(m.model)
+    function re_withgraph(x)
+        WithGraph(re(x), m.g, m.traingraph)
+    end
+
+    return p, re_withgraph
+end
+
 (l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
 (l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
 
@@ -86,15 +100,15 @@ julia> m(g, x)
 ```
 """
 struct GNNChain{T} <: GNNLayer
-  layers::T
-
-  GNNChain(xs...) = new{typeof(xs)}(xs)
-
-  function GNNChain(; kw...)
-    :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
-    isempty(kw) && return new{Tuple{}}(())
-    new{typeof(values(kw))}(values(kw))
-  end
+    layers::T
+
+    GNNChain(xs...) = new{typeof(xs)}(xs)
+
+    function GNNChain(; kw...)
+        :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
+        isempty(kw) && return new{Tuple{}}(())
+        new{typeof(values(kw))}(values(kw))
+    end
 end
 
 @forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
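
Flux's generic `destructure` recurses into every functor field, which would try to flatten the tied graph as well (see FluxML/Flux.jl#1733); the overload added here flattens only the wrapped model and splices `g` back in on reconstruction. A rough sketch of the intended round trip under that overload (layer and graph sizes are made up):

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(4, 12)
model = WithGraph(GCNConv(2 => 3), g)  # ties the layer to a fixed graph

p, re = Flux.destructure(model)        # p holds only the GCNConv parameters
x = rand(Float32, 2, 4)
@assert re(p)(x) ≈ model(x)            # rebuilding preserves the forward pass
```

This round trip is what lets DiffEqFlux drive the graph-tied model from a flat parameter vector inside the NeuralODE.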

test/GNNGraphs/generate.jl

Lines changed: 6 additions & 1 deletion
@@ -4,6 +4,7 @@
     m2 = m ÷ 2
     x = rand(3, n)
     e = rand(4, m2)
+
     g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T)
     @test g.num_nodes == n
     @test g.num_edges == m
@@ -15,8 +16,12 @@
         @test g.edata.e[:,1:m2] == e
         @test g.edata.e[:,m2+1:end] == e
     end
-    g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
+
+    g = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
     @test g.num_nodes == n
     @test g.num_edges == m
+
+    g2 = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
+    @test edge_index(g2) == edge_index(g)
 end
 end

test/examples/node_classification_cora.jl

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ function eval_loss_accuracy(X, y, ids, model, g)
     return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
 end
 
-
 # arguments for the `train` function
 Base.@kwdef mutable struct Args
     η = 5f-3             # learning rate
