
Commit ebdb73b

neural ode example working on cpu and gpu
1 parent: 47a17ce

File tree: 4 files changed (+40 / -26 lines)

examples/neural_ode.jl

Lines changed: 10 additions & 9 deletions
@@ -1,11 +1,12 @@
 # Load the packages
-using GraphNeuralNetworks, JLD2, DiffEqFlux, DifferentialEquations
-using Flux: onehotbatch, onecold, throttle
+using GraphNeuralNetworks, DiffEqFlux, DifferentialEquations
+using Flux: onehotbatch, onecold
 using Flux.Losses: logitcrossentropy
 using Statistics: mean
 using MLDatasets: Cora
 
-device = cpu # `gpu` not working yet
+# device = cpu # `gpu` not working yet
+device = gpu
 
 # LOAD DATA
 data = Cora.dataset()
@@ -39,21 +40,21 @@ node = NeuralODE(WithGraph(node_chain, g),
 model = GNNChain(GCNConv(nin => nhidden, relu),
                  Dropout(0.5),
                  node,
-                 diffeqarray_to_array,
+                 diffeqsol_to_array,
                  Dense(nhidden, nout)) |> device
 
 # Loss
 loss(x, y) = logitcrossentropy(model(g, x), y)
 accuracy(x, y) = mean(onecold(model(g, x)) .== onecold(y))
 
-# Training
-## Model Parameters
-ps = Flux.params(model, node.p);
+# # Training
+# ## Model Parameters
+ps = Flux.params(model);
 
-## Optimizer
+# ## Optimizer
 opt = ADAM(0.01)
 
-## Training Loop
+# ## Training Loop
 for epoch in 1:epochs
     gs = gradient(() -> loss(X, y), ps)
     Flux.Optimise.update!(opt, ps, gs)
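
Two of the example changes interact: a `NeuralODE` layer returns an ODE solution object rather than a plain array, so a conversion step must sit between `node` and the final `Dense` layer, and the explicit `node.p` can be dropped from `Flux.params` because `node` now lives inside the chain, so its parameters are collected from `model` directly. The definition of `diffeqsol_to_array` is not part of this diff; a minimal sketch of what such a helper typically looks like (hypothetical, assuming the ODE state is a feature matrix, so the solution materializes as a 3-d array with time as the last axis):

    # Hypothetical helper, not from this diff: materialize the ODE solution
    # and keep only the state at the final time point.
    diffeqsol_to_array(sol) = Array(sol)[:, :, end]

On the `device = gpu` path one would keep the result on the device rather than collecting it to the CPU with `Array`.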

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@ https://juliagraphs.org/Graphs.jl/latest/types/#AbstractGraph-Type
 https://juliagraphs.org/Graphs.jl/latest/developing/#Developing-Alternate-Graph-Types
 =============================================#
 
-const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V}
-const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
+const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}
+const ADJLIST_T = AbstractVector{T} where T <: AbstractVector{<:Integer}
 const ADJMAT_T = AbstractMatrix
 const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
 const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
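
The tightened aliases now reject non-integer index vectors at the type level, so malformed COO input fails early instead of deep inside a kernel. A quick REPL sketch using only the alias from this diff:

    const COO_T = Tuple{T, T, V} where {T <: AbstractVector{<:Integer}, V}

    ([1, 2, 3], [2, 3, 1], nothing) isa COO_T    # true:  Int source/target indices
    ([1.0, 2.0], [2.0, 3.0], nothing) isa COO_T  # false: Float64 eltype no longer matches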

src/layers/basic.jl

Lines changed: 28 additions & 14 deletions
@@ -13,7 +13,7 @@ abstract type GNNLayer end
 
 
 """
-WithGraph(model, g::GNNGraph; traingraph=false)
+    WithGraph(model, g::GNNGraph; traingraph=false)
 
 A type wrapping the `model` and tying it to the graph `g`.
 In the forward pass, can only take feature arrays as inputs,
@@ -38,17 +38,31 @@ x2 = rand(Float32, 2, 4)
 @assert wg(g2, x2) == model(g2, x2)
 ```
 """
-struct WithGraph{M}
-    model::M
-    g::GNNGraph
-    traingraph::Bool
+struct WithGraph{M, G<:GNNGraph}
+    model::M
+    g::G
+    traingraph::Bool
 end
 
 WithGraph(model, g::GNNGraph; traingraph=false) = WithGraph(model, g, traingraph)
 
 @functor WithGraph
 Flux.trainable(l::WithGraph) = l.traingraph ? (l.model, l.g) : (l.model,)
 
+# Work around
+# https://github.com/FluxML/Flux.jl/issues/1733
+# Revisit after
+# https://github.com/FluxML/Flux.jl/pull/1742
+function Flux.destructure(m::WithGraph)
+    @assert m.traingraph == false # TODO
+    p, re = Flux.destructure(m.model)
+    function re_withgraph(x)
+        WithGraph(re(x), m.g, m.traingraph)
+    end
+
+    return p, re_withgraph
+end
+
 (l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
 (l::WithGraph)(x...; kws...) = l.model(l.g, x...; kws...)
 
@@ -86,15 +100,15 @@ julia> m(g, x)
 ```
 """
 struct GNNChain{T} <: GNNLayer
-  layers::T
-
-  GNNChain(xs...) = new{typeof(xs)}(xs)
-
-  function GNNChain(; kw...)
-    :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
-    isempty(kw) && return new{Tuple{}}(())
-    new{typeof(values(kw))}(values(kw))
-  end
+    layers::T
+
+    GNNChain(xs...) = new{typeof(xs)}(xs)
+
+    function GNNChain(; kw...)
+        :layers in Base.keys(kw) && throw(ArgumentError("a GNNChain cannot have a named layer called `layers`"))
+        isempty(kw) && return new{Tuple{}}(())
+        new{typeof(values(kw))}(values(kw))
+    end
 end
 
 @forward GNNChain.layers Base.getindex, Base.length, Base.first, Base.last,
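
The new `Flux.destructure` method is what lets `NeuralODE` flatten a `WithGraph`-wrapped model into a single parameter vector and rebuild it during the ODE solve; parameterizing the struct on `G<:GNNGraph` additionally lets the wrapper carry a concrete, possibly GPU-resident, graph type. A short usage sketch (the `rand_graph` call and layer sizes are illustrative, not part of this diff):

    g  = rand_graph(4, 8)               # small random graph: 4 nodes, 8 edges
    wg = WithGraph(GCNConv(2 => 2), g)  # tie a conv layer to g
    p, re = Flux.destructure(wg)        # p: flat parameter vector, re: reconstructor
    wg2 = re(p)                         # rebuilds an equivalent WithGraph tied to the same g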

test/examples/node_classification_cora.jl

Lines changed: 0 additions & 1 deletion
@@ -14,7 +14,6 @@ function eval_loss_accuracy(X, y, ids, model, g)
     return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
 end
 
-
 # arguments for the `train` function
 Base.@kwdef mutable struct Args
     η = 5f-3 # learning rate
