Skip to content

Commit 011c956

Browse files
Merge pull request #28 from CarloLucibello/cl/dev
some doc improve
2 parents f6a634b + 7f3cf40 commit 011c956

File tree

9 files changed

+100
-63
lines changed

9 files changed

+100
-63
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
[![codecov](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl)
77

88
A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
9-
Most relevant features are:
9+
Its most relevant features are:
1010
* Provides CUDA support.
1111
* It's integrated with the JuliaGraphs ecosystem.
1212
* Implements many common graph convolutional layers.

docs/src/gnngraph.md

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,105 @@
11
# Graphs
22

3-
TODO
3+
The fundamental graph type in GraphNeuralNetworks.jl is the [`GNNGraph`](@ref),
4+
A GNNGraph `g` is a directed graph with nodes labeled from 1 to `g.num_nodes`.
5+
The underlying implementation allows for efficient application of graph neural network
6+
operators, gpu movement, and storage of node/edge/graph related feature arrays.
47

58
## Graph Creation
6-
9+
A GNNGraph can be created from several different data sources encoding the graph topology:
710

811
```julia
912
using GraphNeuralNetworks, LightGraphs, SparseArrays
1013

1114

12-
# From LightGraphs's graph
13-
lg_graph = erdos_renyi(10, 0.3)
14-
g = GNNGraph(lg_graph)
15-
15+
# Construct GNNGraph from From LightGraphs's graph
16+
lg = erdos_renyi(10, 30)
17+
g = GNNGraph(lg)
1618

17-
# From adjacency matrix
19+
# From an adjacency matrix
1820
A = sprand(10, 10, 0.3)
19-
2021
g = GNNGraph(A)
2122

22-
@assert adjacency_matrix(g) == A
23-
24-
# From adjacency list
25-
adjlist = [[] [] [] ]
26-
23+
# From an adjacency list
24+
adjlist = [[2,3], [1,3], [1,2,4], [3]]
2725
g = GNNGraph(adjlist)
2826

29-
@assert sort.(adjacency_list(g)) == sort.(adjlist)
30-
3127
# From COO representation
32-
source = []
33-
target = []
28+
source = [1,1,2,2,3,3,3,4]
29+
target = [2,3,1,3,1,2,4,3]
3430
g = GNNGraph(source, target)
35-
@assert edge_index(g) == (source, target)
3631
```
3732

38-
We have also seen some useful methods such as [`adjacency_matrix`](@ref) and [`edge_index`](@ref).
39-
33+
See also the related methods [`adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).
4034

4135

4236
## Data Features
4337

4438
```julia
45-
GNNGraph(erods_renyi(10, 30), ndata = (; X=rand(Float32, 32, 10)))
46-
# or equivalently
47-
GNNGraph(sprand(10, 0.3), ndata=rand(Float32, 32, 10))
39+
# Create a graph with a single feature array `x` associated to nodes
40+
g = GNNGraph(erdos_renyi(10, 30), ndata = (; x = rand(Float32, 32, 10)))
41+
# Equivalent definition
42+
g = GNNGraph(erdos_renyi(10, 30), ndata = rand(Float32, 32, 10))
43+
44+
# You can have multiple feature arrays
45+
g = GNNGraph(erdos_renyi(10, 30), ndata = (; x=rand(Float32, 32, 10), y=rand(Float32, 10)))
46+
4847

49-
g = GNNGraph(sprand(10, 0.3), ndata = (X=rand(Float32, 32, 10), y=rand(Float32, 10)))
48+
# Attach an array with edge features
49+
g = GNNGraph(erdos_renyi(10, 30), edata = rand(Float32, 30))
5050

51-
g = GNNGraph(g, edata=rand(Float32, 6, g.num_edges))
51+
# Create a new graph from previous one, inheriting edge data
52+
# but replacing node data
53+
g′ = GNNGraph(g, ndata =(; z = ones(Float32, 16, 10)))
5254
```
5355

5456

5557
## Graph Manipulation
5658

5759
```julia
58-
g = add_self_loops(g)
60+
g = add_self_loops(g)
5961

60-
g = remove_self_loops(g)
62+
g = remove_self_loops(g)
6163
```
6264

6365
## Batches and Subgraphs
6466

6567
```julia
6668
using Flux
6769

68-
gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(3,10)) for _ in 1:100])
70+
gall = Flux.batch([GNNGraph(erdos_renyi(10, 30), ndata=rand(Float32,3,10)) for _ in 1:160])
6971

70-
getgraph(gall, 2:3)
72+
g23 = getgraph(gall, 2:3)
73+
@assert g23.num_graphs == 16
74+
@assert g23.num_nodes == 32
75+
@assert g23.num_edges == 60
7176

7277

7378
# DataLoader compatibility
7479
train_loader = Flux.Data.DataLoader(gall, batchsize=16, shuffle=true)
7580

76-
for g for gall
81+
for g in train_loader
7782
@assert g.num_graphs == 16
7883
@assert g.num_nodes == 160
79-
@assert size(g.ndata.X) = (3, 160)
84+
@assert size(g.ndata.x) = (3, 160)
8085
.....
8186
end
8287
```
8388

84-
## LightGraphs integration
89+
## JuliaGraphs ecosystem integration
90+
91+
Since `GNNGraph <: LightGraphs.AbstractGraph`, we can use any functionality from LightGraphs.
8592

8693
```julia
8794
@assert LightGraphs.isdirected(g)
8895
```
8996

9097
## GPU movement
9198

99+
Move a `GNNGraph` to a CUDA device using `Flux.gpu` method.
100+
92101
```julia
93102
using Flux: gpu
94103

95-
g |> gpu
104+
g_gpu = g |> gpu
96105
```
97-
98-
## Other methods

docs/src/index.md

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,36 @@
11
# GraphNeuralNetworks
22

3-
Documentation for [GraphNeuralNetworks](https://github.com/CarloLucibello/GraphNeuralNetworks.jl).
3+
This is the documentation page for the [GraphNeuralNetworks.jl](https://github.com/CarloLucibello/GraphNeuralNetworks.jl) library.
44

5+
A graph neural network library for Julia based on the deep learning framework [Flux.jl](https://github.com/FluxML/Flux.jl).
6+
Its most relevant features are:
7+
* Provides CUDA support.
8+
* It's integrated with the JuliaGraphs ecosystem.
9+
* Implements many common graph convolutional layers.
10+
* Performs fast operations on batched graphs.
11+
* Makes it easy to define custom graph convolutional layers.
512

6-
## Getting Started
13+
14+
15+
16+
## Package overview
17+
18+
### Data preparation
19+
20+
21+
```
22+
using LightGraphs
23+
24+
lg = LightGraphs.Graph(5) # create a light's graph graph
25+
add_edge!(g, 1, 2)
26+
add_edge!(g, 1, 3)
27+
add_edge!(g, 2, 4)
28+
add_edge!(g, 2, 5)
29+
add_edge!(g, 3, 4)
30+
31+
g = GNNGraph(g)
32+
```
33+
### Model building
34+
35+
### Training
736

docs/src/messagepassing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ The message passing corresponds to the following operations
1515
\end{aligned}
1616
```
1717
where ``\phi`` is expressed by the [`compute_message`](@ref) function,
18-
``\gamma_x`` and ``\gamma_v`` by [`update_node`](@ref) and [`update_edge`](@ref)
18+
``\gamma_x`` and ``\gamma_e`` by [`update_node`](@ref) and [`update_edge`](@ref)
1919
respectively.
2020

2121
See [`GraphConv`](ref) and [`GATConv`](ref)'s implementations as usage examples.

examples/graph_classification_tudataset.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function eval_loss_accuracy(model, data_loader, device)
1919
g = g |> device
2020
n = g.num_graphs
2121
y = g.gdata.y
22-
= model(g, g.ndata.X) |> vec
22+
= model(g, g.ndata.x) |> vec
2323
loss += logitbinarycrossentropy(ŷ, y) * n
2424
acc += mean((ŷ .> 0) .== y) * n
2525
ntot += n
@@ -30,16 +30,16 @@ end
3030
function getdataset()
3131
data = TUDataset("MUTAG")
3232

33-
X = Array{Float32}(onehotbatch(data.node_labels, 0:6))
33+
x = Array{Float32}(onehotbatch(data.node_labels, 0:6))
3434
y = (1 .+ Array{Float32}(data.graph_labels)) ./ 2
3535
@assert all(([0,1]), y) # binary classification
3636
# The dataset also has edge features but we won't be using them
37-
E = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
37+
e = Array{Float32}(onehotbatch(data.edge_labels, sort(unique(data.edge_labels))))
3838

3939
return GNNGraph(data.source, data.target,
4040
num_nodes=data.num_nodes,
4141
graph_indicator=data.graph_indicator,
42-
ndata=(; X), edata=(; E), gdata=(; y))
42+
ndata=(; x), edata=(; e), gdata=(; y))
4343
end
4444

4545
# arguments for the `train` function
@@ -83,7 +83,7 @@ function train(; kws...)
8383

8484
# DEFINE MODEL
8585

86-
nin = size(gtrain.ndata.X, 1)
86+
nin = size(gtrain.ndata.x, 1)
8787
nhidden = args.nhidden
8888

8989
model = GNNChain(GraphConv(nin => nhidden, relu),
@@ -111,7 +111,7 @@ function train(; kws...)
111111
for g in train_loader
112112
g = g |> device
113113
gs = Flux.gradient(ps) do
114-
= model(g, g.ndata.X) |> vec
114+
= model(g, g.ndata.x) |> vec
115115
logitbinarycrossentropy(ŷ, g.gdata.y)
116116
end
117117
Flux.Optimise.update!(opt, ps, gs)

examples/node_classification_cora.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function train(; kws...)
6262
ps = Flux.params(model)
6363
opt = ADAM(args.η)
6464

65-
@info "NUM NODES: $(g.num_nodes) NUM EDGES: $(g.num_edges)"
65+
@info g
6666

6767
## LOGGING FUNCTION
6868
function report(epoch)

src/gnngraph.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ end
156156

157157
function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata)
158158

159-
ndata = normalize_graphdata(ndata, :X)
160-
edata = normalize_graphdata(edata, :E)
161-
gdata = normalize_graphdata(gdata, :U)
159+
ndata = normalize_graphdata(ndata, :x)
160+
edata = normalize_graphdata(edata, :e)
161+
gdata = normalize_graphdata(gdata, :u)
162162

163163
GNNGraph(g.graph,
164164
g.num_nodes, g.num_edges, g.num_graphs,

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-
143143
144144
Performs:
145145
```math
146-
\mathbf{x}_i' = W^1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W^2 \mathbf{x}_j)
146+
\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j
147147
```
148148
149149
where the aggregation type is selected by `aggr`.

test/gnngraph.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,25 @@
143143
U = rand(10, g.num_graphs)
144144

145145
g = GNNGraph(g, ndata=X, edata=E, gdata=U)
146-
@test g.ndata.X === X
147-
@test g.edata.E === E
148-
@test g.gdata.U === U
146+
@test g.ndata.x === X
147+
@test g.edata.e === E
148+
@test g.gdata.u === U
149149

150150
# Check no args
151151
g = GNNGraph(g)
152-
@test g.ndata.X === X
153-
@test g.edata.E === E
154-
@test g.gdata.U === U
152+
@test g.ndata.x === X
153+
@test g.edata.e === E
154+
@test g.gdata.u === U
155+
155156

156157
# multiple features names
157-
g = GNNGraph(g, ndata=(x=2X, g.ndata...), edata=(e=2E, g.edata...), gdata=(u=2U, g.gdata...))
158-
@test g.ndata.X === X
159-
@test g.edata.E === E
160-
@test g.gdata.U === U
161-
@test g.ndata.x 2X
162-
@test g.edata.e 2E
163-
@test g.gdata.u 2U
158+
g = GNNGraph(g, ndata=(x2=2X, g.ndata...), edata=(e2=2E, g.edata...), gdata=(u2=2U, g.gdata...))
159+
@test g.ndata.x === X
160+
@test g.edata.e === E
161+
@test g.gdata.u === U
162+
@test g.ndata.x2 2X
163+
@test g.edata.e2 2E
164+
@test g.gdata.u2 2U
164165
end
165166

166167
@testset "LearnBase and DataLoader compat" begin

0 commit comments

Comments
 (0)