Skip to content

Commit 375b787

Browse files
GNNGraph inherits from Graphs.AbstractGraph (#63)
* improve GNNGraph docstring * inherit from AbstractGraph * cleanup
1 parent f647c1e commit 375b787

File tree

5 files changed

+33
-15
lines changed

5 files changed

+33
-15
lines changed

docs/src/api/gnngraph.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ CurrentModule = GraphNeuralNetworks
66

77
Documentation page for the graph type `GNNGraph` provided GraphNeuralNetworks.jl and its related methods.
88

9+
Besides the methods documented here, one can rely on the large set of functionalities
10+
given by [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)
11+
since `GNNGraph` inherits from `Graphs.AbstractGraph`.
12+
913
## Index
1014

1115
```@index

docs/src/gnngraph.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ A GNNGraph `g` is a directed graph with nodes labeled from 1 to `g.num_nodes`.
55
The underlying implementation allows for efficient application of graph neural network
66
operators, gpu movement, and storage of node/edge/graph related feature arrays.
77

8+
`GNNGraph` inherits from [Graphs.jl](https://github.com/JuliaGraphs/Graphs.jl)'s `AbstractGraph`,
9+
therefore it supports most functionality from that library.
10+
811
## Graph Creation
912
A GNNGraph can be created from several different data sources encoding the graph topology:
1013

src/gnngraph.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,34 @@ const ADJMAT_T = AbstractMatrix
1111
const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
1212
const CUMAT_T = Union{AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
1313

14-
"""
14+
"""
1515
GNNGraph(data; [graph_type, ndata, edata, gdata, num_nodes, graph_indicator, dir])
1616
GNNGraph(g::GNNGraph; [ndata, edata, gdata])
1717
18-
A type representing a graph structure and storing also
19-
feature arrays associated to nodes, edges, and to the whole graph (global features).
18+
A type representing a graph structure that also stores
19+
feature arrays associated to nodes, edges, and the graph itself.
2020
21-
A `GNNGraph` can be constructed out of different objects `data` expressing
22-
the connections inside the graph. The internal representation type
21+
A `GNNGraph` can be constructed out of different `data` objects
22+
expressing the connections inside the graph. The internal representation type
2323
is determined by `graph_type`.
2424
2525
When constructed from another `GNNGraph`, the internal graph representation
26-
is preserved and shared. The node/edge/global features are transmitted
27-
as well, unless explicitely changed though keyword arguments.
26+
is preserved and shared. The node/edge/graph features are retained
27+
as well, unless explicitely set by the keyword arguments
28+
`ndata`, `edata`, and `gdata`.
2829
2930
A `GNNGraph` can also represent multiple graphs batched togheter
3031
(see [`Flux.batch`](@ref) or [`SparseArrays.blockdiag`](@ref)).
3132
The field `g.graph_indicator` contains the graph membership
3233
of each node.
3334
34-
A `GNNGraph` is a Graphs' `AbstractGraph`, therefore any functionality
35-
from the Graphs' graph library can be used on it.
35+
`GNNGraph`s are always directed graphs, therefore each edge is defined
36+
by a source node and a target node (see [`edge_index`](@ref)).
37+
Self loops (edges connecting a node to itself) and multiple edges
38+
(more than one edge between the same pair of nodes) are supported.
39+
40+
A `GNNGraph` is a Graphs.jl's `AbstractGraph`, therefore it supports most
41+
functionality from that library.
3642
3743
# Arguments
3844
@@ -54,9 +60,9 @@ from the Graphs' graph library can be used on it.
5460
Possible values are `:out` and `:in`. Default `:out`.
5561
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.
5662
- `graph_indicator`: For batched graphs, a vector containing the graph assigment of each node. Default `nothing`.
57-
- `ndata`: Node features. A named tuple of arrays whose last dimension has size num_nodes.
58-
- `edata`: Edge features. A named tuple of arrays whose whose last dimension has size num_edges.
59-
- `gdata`: Global features. A named tuple of arrays whose has size num_graphs.
63+
- `ndata`: Node features. A named tuple of arrays whose last dimension has size `num_nodes`.
64+
- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`.
65+
- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`.
6066
6167
# Usage.
6268
@@ -97,7 +103,7 @@ g = g |> gpu
97103
source, target = edge_index(g)
98104
```
99105
"""
100-
struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
106+
struct GNNGraph{T<:Union{COO_T,ADJMAT_T}} <: AbstractGraph{Int}
101107
graph::T
102108
num_nodes::Int
103109
num_edges::Int

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
3939

4040
if n == 1
4141
# If last array dimension is not 1, add a new dimension.
42-
# This is mostly usefule to reshape globale feature vectors
43-
# of size D to Dx1 matrices.
42+
# This is mostly useful to reshape graph feature vectors
43+
# of size D into Dx1 matrices.
4444
function unsqz(v)
4545
if v isa AbstractArray && size(v)[end] != 1
4646
v = reshape(v, size(v)..., 1)

test/gnngraph.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,11 @@
277277
d = Flux.Data.DataLoader(g, batchsize = 2, shuffle=false)
278278
@test first(d) == getgraph(g, 1:2)
279279
end
280+
281+
@testset "Graphs.jl integration" begin
282+
g = GNNGraph(erdos_renyi(10, 20))
283+
@test g isa Graphs.AbstractGraph
284+
end
280285
end
281286

282287

0 commit comments

Comments
 (0)