Skip to content

Commit 00e4f5a

Browse files
layers do not store graphs internally
1 parent 14ab918 commit 00e4f5a

File tree

8 files changed

+220
-598
lines changed

8 files changed

+220
-598
lines changed

README.md

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
![](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/actions/workflows/ci.yml/badge.svg)
66
[![codecov](https://codecov.io/gh/FluxML/GraphNeuralNetworks.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/CarloLucibello/GraphNeuralNetworks.jl)
77

8-
GraphNeuralNetworks is a geometric deep learning library for [Flux](https://github.com/FluxML/Flux.jl). This library aims to be compatible with packages from [JuliaGraphs](https://github.com/JuliaGraphs) ecosystem and have support of CUDA GPU acceleration with [CUDA](https://github.com/JuliaGPU/CUDA.jl). Message passing scheme is implemented as a flexbile framework and fused with Graph Network block scheme. GraphNeuralNetworks is compatible with other packages that are composable with Flux.
8+
GraphNeuralNetworks (GNN) is a graph neural network library for Julia based on the [Flux.jl](https://github.com/FluxML/Flux.jl) deep learning framework.
99

10-
Suggestions, issues and pull requsts are welcome.
1110

1211
## Installation
1312

@@ -17,15 +16,15 @@ Suggestions, issues and pull requsts are welcome.
1716

1817
## Features
1918

20-
* Extend Flux deep learning framework in Julia and compatible with Flux layers.
21-
* Support of CUDA GPU with CUDA.jl
22-
* Integrate with existing JuliaGraphs ecosystem
23-
* Support generic graph neural network architectures
24-
* Variable graph inputs are supported. You use it when diverse graph structures are prepared as inputs to the same model.
19+
* Based on the Flux deep learning framework.
20+
* CUDA support.
21+
* Integrated with the JuliaGraphs ecosystem.
22+
* Supports generic graph neural network architectures.
23+
* Easy to define custom graph convolutional layers.
2524

2625
## Featured Graphs
2726

28-
GraphNeuralNetworks handles graph data (the topology plus node/vertex/graph features)
27+
GraphNeuralNetworks handles graph data (the graph topology + node/edge/global features)
2928
thanks to the type `FeaturedGraph`.
3029

3130
A `FeaturedGraph` can be constructed out of
@@ -34,26 +33,45 @@ adjacency matrices, adjacency lists, LightGraphs' types...
3433
```julia
3534
fg = FeaturedGraph(adj_list)
3635
```
36+
3737
## Graph convolutional layers
3838

3939
Construct a GCN layer:
4040

4141
```julia
42-
GCNConv([fg,] input_dim => output_dim, relu)
42+
GCNConv(input_dim => output_dim, relu)
4343
```
4444

45-
## Use it as you use Flux
45+
## Usage Example
4646

4747
```julia
48-
model = Chain(GCNConv(fg, 1024 => 512, relu),
49-
Dropout(0.5),
50-
GCNConv(fg, 512 => 128),
51-
Dense(128, 10))
52-
## Loss
53-
loss(x, y) = logitcrossentropy(model(x), y)
54-
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
55-
56-
## Training
48+
struct GNN
49+
conv1
50+
conv2
51+
dense
52+
53+
function GNN()
54+
new(GCNConv(1024=>512, relu),
55+
GCNConv(512=>128, relu),
56+
Dense(128, 10))
57+
end
58+
end
59+
60+
@functor GNN
61+
62+
function (net::GNN)(g, x)
63+
x = net.conv1(g, x)
64+
x = dropout(x, 0.5)
65+
x = net.conv2(g, x)
66+
x = net.dense(x)
67+
return x
68+
end
69+
70+
model = GNN()
71+
72+
loss(x, y) = logitcrossentropy(model(fg, x), y)
73+
accuracy(x, y) = mean(onecold(model(fg, x)) .== onecold(y))
74+
5775
ps = Flux.params(model)
5876
train_data = [(train_X, train_y)]
5977
opt = ADAM(0.01)

src/GraphNeuralNetworks.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ include("layers/msgpass.jl")
6464

6565
include("layers/conv.jl")
6666
include("layers/pool.jl")
67-
include("models.jl")
6867
include("layers/misc.jl")
6968

7069

src/featuredgraph.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,7 @@ https://juliagraphs.org/LightGraphs.jl/latest/types/#AbstractGraph-Type
55
https://juliagraphs.org/LightGraphs.jl/latest/developing/#Developing-Alternate-Graph-Types
66
=============================================#
77

8-
abstract type AbstractFeaturedGraph <: AbstractGraph{Int} end
9-
10-
"""
11-
NullGraph()
12-
13-
Null object for `FeaturedGraph`.
14-
"""
15-
struct NullGraph <: AbstractFeaturedGraph end
16-
17-
const COO_T = Tuple{T, T, V} where {T <: AbstractVector,V}
8+
const COO_T = Tuple{T, T, V} where {T <: AbstractVector, V}
189
const ADJLIST_T = AbstractVector{T} where T <: AbstractVector
1910
const ADJMAT_T = AbstractMatrix
2011
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
@@ -93,7 +84,7 @@ source, target = edge_index(fg)
9384
9485
See also [`graph`](@ref), [`edge_index`](@ref), [`node_feature`](@ref), [`edge_feature`](@ref), and [`global_feature`](@ref)
9586
"""
96-
struct FeaturedGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractFeaturedGraph
87+
struct FeaturedGraph{T<:Union{COO_T,ADJMAT_T}}
9788
graph::T
9889
num_nodes::Int
9990
num_edges::Int

0 commit comments

Comments
 (0)