Skip to content

Commit 195bb6c

Browse files
Merge pull request #64 from CarloLucibello/cl/grah
refactor GNNGraph into submodule + implement add_edges
2 parents 375b787 + df874da commit 195bb6c

18 files changed

+973
-814
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
12+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1213
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1314
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1415
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
@@ -17,6 +18,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1718
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1819
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2022
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2123
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2224
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -28,12 +30,14 @@ CUDA = "3.3"
2830
ChainRulesCore = "1"
2931
DataStructures = "0.18"
3032
Flux = "0.12.7"
33+
Functors = "0.2"
3134
Graphs = "1.4"
3235
KrylovKit = "0.5"
3336
LearnBase = "0.4, 0.5"
3437
MacroTools = "0.5"
3538
NNlib = "0.7"
3639
NNlibCUDA = "0.1"
40+
Reexport = "1"
3741
StatsBase = "0.32, 0.33"
3842
TestEnv = "1"
3943
julia = "1.6"

docs/make.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Flux, NNlib, GraphNeuralNetworks, Graphs, SparseArrays
22
using Documenter
33

4-
DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup, :(using GraphNeuralNetworks); recursive=true)
4+
DocMeta.setdocmeta!(GraphNeuralNetworks, :DocTestSetup,
5+
:(using GraphNeuralNetworks, Graphs, SparseArrays, NNlib, Flux);
6+
recursive=true)
57

68
makedocs(;
79
modules=[GraphNeuralNetworks, NNlib, Flux, Graphs, SparseArrays],

docs/src/api/gnngraph.md

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@ CurrentModule = GraphNeuralNetworks
44

55
# GNNGraph
66

7-
Documentation page for the graph type `GNNGraph` provided GraphNeuralNetworks.jl and its related methods.
7+
Documentation page for the graph type `GNNGraph` provided by GraphNeuralNetworks.jl and related methods.
8+
9+
```@contents
10+
Pages = ["gnngraph.md"]
11+
Depth = 5
12+
```
813

914
Besides the methods documented here, one can rely on the large set of functionalities
1015
given by [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)
@@ -17,16 +22,41 @@ Order = [:type, :function]
1722
Pages = ["gnngraph.md"]
1823
```
1924

20-
## Docs
25+
## GNNGraph type
2126

2227
```@autodocs
23-
Modules = [GraphNeuralNetworks]
28+
Modules = [GraphNeuralNetworks.GNNGraphs]
2429
Pages = ["gnngraph.jl"]
2530
Private = false
2631
```
2732

33+
## Query
34+
35+
```@autodocs
36+
Modules = [GraphNeuralNetworks.GNNGraphs]
37+
Pages = ["query.jl"]
38+
Private = false
39+
```
40+
2841
```@docs
29-
Flux.batch
30-
SparseArrays.blockdiag
3142
Graphs.adjacency_matrix
43+
Graphs.degree
44+
Graphs.outneighbors
45+
Graphs.inneighbors
46+
```
47+
48+
## Transform
49+
50+
```@autodocs
51+
Modules = [GraphNeuralNetworks.GNNGraphs]
52+
Pages = ["transform.jl"]
53+
Private = false
54+
```
55+
56+
## Generate
57+
58+
```@autodocs
59+
Modules = [GraphNeuralNetworks.GNNGraphs]
60+
Pages = ["generate.jl"]
61+
Private = false
3262
```

docs/src/gnngraph.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ A GNNGraph can be created from several different data sources encoding the graph
1515
using GraphNeuralNetworks, Graphs, SparseArrays
1616

1717

18-
# Construct GNNGraph from From Graphs's graph
18+
# Construct a GNNGraph from from a Graphs.jl's graph
1919
lg = erdos_renyi(10, 30)
2020
g = GNNGraph(lg)
2121

22+
# Same as above using convenience method rand_graph
23+
g = rand_graph(10, 30)
24+
2225
# From an adjacency matrix
2326
A = sprand(10, 10, 0.3)
2427
g = GNNGraph(A)
@@ -33,7 +36,7 @@ target = [2,3,1,3,1,2,4,3]
3336
g = GNNGraph(source, target)
3437
```
3538

36-
See also the related methods [`adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).
39+
See also the related methods [`Graphs.adjacency_matrix`](@ref), [`edge_index`](@ref), and [`adjacency_list`](@ref).
3740

3841
## Basic Queries
3942

@@ -123,21 +126,21 @@ for g in train_loader
123126
.....
124127
end
125128

126-
# Access the nodes' graph memberships through
127-
gall.graph_indicator
129+
# Access the nodes' graph memberships
130+
graph_indicator(gall)
128131
```
129132

130133
## Graph Manipulation
131134

132135
```julia
133136
g′ = add_self_loops(g)
134-
135137
g′ = remove_self_loops(g)
138+
g′ = add_edges(g, [1, 2], [2, 3]) # add edges 1->2 and 2->3
136139
```
137140

138141
## JuliaGraphs ecosystem integration
139142

140-
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.
143+
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.
141144

142145
```julia
143146
@assert Graphs.isdirected(g)

src/GNNGraphs/GNNGraphs.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
module GNNGraphs
2+
3+
using SparseArrays
4+
using Functors: @functor
5+
using CUDA
6+
import Graphs
7+
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
8+
import Flux
9+
using Flux: batch
10+
import NNlib
11+
import LearnBase
12+
import StatsBase
13+
using LearnBase: getobs
14+
import KrylovKit
15+
using ChainRulesCore
16+
using LinearAlgebra, Random
17+
18+
include("gnngraph.jl")
19+
export GNNGraph, node_features, edge_features, graph_features
20+
21+
include("query.jl")
22+
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
23+
graph_indicator
24+
25+
include("transform.jl")
26+
export add_edges, add_self_loops, remove_self_loops, getgraph
27+
28+
include("generate.jl")
29+
export rand_graph
30+
31+
32+
include("convert.jl")
33+
include("utils.jl")
34+
35+
export
36+
# from Graphs
37+
adjacency_matrix, degree, outneighbors, inneighbors,
38+
# from SparseArrays
39+
sprand, sparse, blockdiag,
40+
# from Flux
41+
batch
42+
43+
end #module
File renamed without changes.

src/GNNGraphs/generate.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
rand_graph(n, m; directed=false, kws...)
3+
4+
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes.
5+
6+
If `directed=false` the output will contain `2m` edges:
7+
the reverse edge of each edge will be present.
8+
If `directed=true` instead, `m` unrelated edges are generated.
9+
10+
Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor.
11+
"""
12+
function rand_graph(n::Integer, m::Integer; directed=false, kws...)
13+
return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
14+
end

0 commit comments

Comments
 (0)