Skip to content

Commit f8a2695

Browse files
refactor GNNGraph into its own module; implement add_edges
1 parent 375b787 commit f8a2695

File tree

16 files changed

+965
-188
lines changed

16 files changed

+965
-188
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
12+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1213
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1314
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
1415
LearnBase = "7f8f8fb0-2700-5f03-b4bd-41f8cfc144b6"
@@ -17,6 +18,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1718
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1819
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1920
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
21+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2022
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2123
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2224
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
@@ -28,12 +30,14 @@ CUDA = "3.3"
2830
ChainRulesCore = "1"
2931
DataStructures = "0.18"
3032
Flux = "0.12.7"
33+
Functors = "0.2"
3134
Graphs = "1.4"
3235
KrylovKit = "0.5"
3336
LearnBase = "0.4, 0.5"
3437
MacroTools = "0.5"
3538
NNlib = "0.7"
3639
NNlibCUDA = "0.1"
40+
Reexport = "1"
3741
StatsBase = "0.32, 0.33"
3842
TestEnv = "1"
3943
julia = "1.6"

docs/src/api/gnngraph.md

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,52 @@ Pages = ["gnngraph.md"]
1919

2020
## Docs
2121

22+
### GNNGraph
23+
24+
```@autodocs
25+
Modules = [GraphNeuralNetworks]
26+
Pages = ["GNNGraphs/gnngraph.jl"]
27+
Private = false
28+
```
29+
30+
### Query
31+
32+
```@autodocs
33+
Modules = [GraphNeuralNetworks]
34+
Pages = ["GNNGraphs/query.jl"]
35+
Private = false
36+
```
37+
38+
```@docs
39+
Graphs.adjacency_matrix
40+
Graphs.degree
41+
Graphs.outneighbors
42+
Graphs.inneighbors
43+
```
44+
45+
### Transform
46+
2247
```@autodocs
2348
Modules = [GraphNeuralNetworks]
24-
Pages = ["gnngraph.jl"]
49+
Pages = ["GNNGraphs/transform.jl"]
2550
Private = false
2651
```
2752

2853
```@docs
2954
Flux.batch
3055
SparseArrays.blockdiag
31-
Graphs.adjacency_matrix
56+
```
57+
58+
### Generate
59+
60+
```@autodocs
61+
Modules = [GraphNeuralNetworks]
62+
Pages = ["GNNGraphs/generate.jl"]
63+
Private = false
64+
```
65+
66+
### Related methods
67+
68+
```@docs
69+
SparseArrays.sparse
3270
```

docs/src/gnngraph.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ A GNNGraph can be created from several different data sources encoding the graph
1515
using GraphNeuralNetworks, Graphs, SparseArrays
1616

1717

18-
# Construct GNNGraph from From Graphs's graph
18+
# Construct a GNNGraph from from a Graphs.jl's graph
1919
lg = erdos_renyi(10, 30)
2020
g = GNNGraph(lg)
2121

22+
# Same as above using convenience method rand_graph
23+
g = rand_graph(10, 30)
24+
2225
# From an adjacency matrix
2326
A = sprand(10, 10, 0.3)
2427
g = GNNGraph(A)
@@ -123,21 +126,21 @@ for g in train_loader
123126
.....
124127
end
125128

126-
# Access the nodes' graph memberships through
127-
gall.graph_indicator
129+
# Access the nodes' graph memberships
130+
graph_indicator(gall)
128131
```
129132

130133
## Graph Manipulation
131134

132135
```julia
133136
g′ = add_self_loops(g)
134-
135137
g′ = remove_self_loops(g)
138+
g′ = add_edges(g, [1, 2], [2, 3]) # add edges 1->2 and 2->3
136139
```
137140

138141
## JuliaGraphs ecosystem integration
139142

140-
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.
143+
Since `GNNGraph <: Graphs.AbstractGraph`, we can use any functionality from Graphs.jl.
141144

142145
```julia
143146
@assert Graphs.isdirected(g)

src/GNNGraphs/GNNGraphs.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
module GNNGraphs
2+
3+
using SparseArrays
4+
using Functors: @functor
5+
using CUDA
6+
import Graphs
7+
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree
8+
import Flux
9+
using Flux: batch
10+
import NNlib
11+
import LearnBase
12+
import StatsBase
13+
using LearnBase: getobs
14+
import KrylovKit
15+
using ChainRulesCore
16+
using LinearAlgebra, Random
17+
18+
include("gnngraph.jl")
19+
export GNNGraph, node_features, edge_features, graph_features
20+
21+
include("query.jl")
22+
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
23+
graph_indicator
24+
25+
include("transform.jl")
26+
export add_edges, add_self_loops, remove_self_loops, getgraph
27+
28+
include("generate.jl")
29+
export rand_graph
30+
31+
32+
include("convert.jl")
33+
include("utils.jl")
34+
35+
export
36+
# from Graphs
37+
adjacency_matrix, degree, outneighbors, inneighbors,
38+
# from SparseArrays
39+
sprand, sparse, blockdiag,
40+
# from Flux
41+
batch
42+
43+
end #module
File renamed without changes.

src/GNNGraphs/generate.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
rand_graph(n, m; directed=false, kws...)
3+
4+
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes.
5+
6+
If `directed=false` the output will contain `2m` edges:
7+
the reverse edge of each edge will be present.
8+
If `directed=true` instead, `m` unrelated edges are generated.
9+
10+
Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor.
11+
"""
12+
function rand_graph(n::Integer, m::Integer; directed=false, kws...)
13+
return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
14+
end

0 commit comments

Comments
 (0)