Skip to content

Commit da7b1e4

Browse files
implement unbatch and add_nodes (#65)
* add unbatch * implement add_nodes * cleanup
1 parent 195bb6c commit da7b1e4

File tree

8 files changed

+209
-28
lines changed

8 files changed

+209
-28
lines changed

docs/src/gnngraph.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ lg = erdos_renyi(10, 30)
2020
g = GNNGraph(lg)
2121

2222
# Same as above using convenience method rand_graph
23-
g = rand_graph(10, 30)
23+
g = rand_graph(10, 60)
2424

2525
# From an adjacency matrix
2626
A = sprand(10, 10, 0.3)

src/GNNGraphs/GNNGraphs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ export GNNGraph, node_features, edge_features, graph_features
2121
include("query.jl")
2222
export edge_index, adjacency_list, normalized_laplacian, scaled_laplacian,
2323
graph_indicator
24-
24+
2525
include("transform.jl")
26-
export add_edges, add_self_loops, remove_self_loops, getgraph
26+
export add_nodes, add_edges, add_self_loops, remove_self_loops, getgraph
2727

2828
include("generate.jl")
2929
export rand_graph
@@ -38,6 +38,6 @@ export
3838
# from SparseArrays
3939
sprand, sparse, blockdiag,
4040
# from Flux
41-
batch
41+
batch, unbatch
4242

4343
end #module

src/GNNGraphs/generate.jl

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,52 @@
11
"""
2-
rand_graph(n, m; directed=false, kws...)
2+
rand_graph(n, m; bidirected=true, kws...)
33
4-
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes.
4+
Generate a random (Erdós-Renyi) `GNNGraph` with `n` nodes
5+
and `m` edges.
56
6-
If `directed=false` the output will contain `2m` edges:
7-
the reverse edge of each edge will be present.
8-
If `directed=true` instead, `m` unrelated edges are generated.
7+
If `bidirected=true` the reverse edge of each edge will be present.
8+
If `bidirected=false` instead, `m` unrelated edges are generated.
9+
In any case, the output graph will contain no self-loops or multi-edges.
910
10-
Additional keyword argument will be fed to the [`GNNGraph`](@ref) constructor.
11+
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
12+
13+
# Usage
14+
15+
```juliarepl
16+
julia> g = rand_graph(5, 4, bidirected=false)
17+
GNNGraph:
18+
num_nodes = 5
19+
num_edges = 4
20+
num_graphs = 1
21+
ndata:
22+
edata:
23+
gdata:
24+
25+
26+
julia> edge_index(g)
27+
([1, 3, 3, 4], [5, 4, 5, 2])
28+
29+
# In the bidirected case, edge data will be duplicated on the reverse edges if needed.
30+
julia> g = rand_graph(5, 4, edata=rand(16, 2))
31+
GNNGraph:
32+
num_nodes = 5
33+
num_edges = 4
34+
num_graphs = 1
35+
ndata:
36+
edata:
37+
e => (16, 4)
38+
gdata:
39+
40+
# Each edge has a reverse
41+
julia> edge_index(g)
42+
([1, 3, 3, 4], [3, 4, 1, 3])
43+
44+
```
1145
"""
12-
function rand_graph(n::Integer, m::Integer; directed=false, kws...)
13-
return GNNGraph(Graphs.erdos_renyi(n, m, is_directed=directed); kws...)
46+
function rand_graph(n::Integer, m::Integer; bidirected=true, kws...)
47+
if bidirected
48+
@assert iseven(m) "Need even number of edges for bidirected graphs, given m=$m."
49+
end
50+
m2 = bidirected ? m÷2 : m
51+
return GNNGraph(Graphs.erdos_renyi(n, m2, is_directed=!bidirected); kws...)
1452
end

src/GNNGraphs/gnngraph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ functionality from that library.
5656
Optionally, also edge weights can be given: `(source, target, weights)`.
5757
- `:sparse`. A sparse adjacency matrix representation.
5858
- `:dense`. A dense adjacency matrix representation.
59-
Default `:coo`.
59+
Defaults to `:coo`, currently the most supported type.
6060
- `dir`: The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
6161
Possible values are `:out` and `:in`. Default `:out`.
6262
- `num_nodes`: The number of nodes. If not specified, inferred from `g`. Default `nothing`.

src/GNNGraphs/query.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Graphs.is_directed(::Type{<:GNNGraph}) = true
6262
Return the adjacency list representation (a vector of vectors)
6363
of the graph `g`.
6464
65-
Calling `a` the adjacency list, if `dir=:out`
65+
Calling `a` the adjacency list, if `dir=:out` than
6666
`a[i]` will contain the neighbors of node `i` through
6767
outgoing edges. If `dir=:in`, it will contain neighbors from
6868
incoming edges instead.
@@ -75,7 +75,7 @@ end
7575

7676
function Graphs.adjacency_matrix(g::GNNGraph{<:COO_T}, T::DataType=Int; dir=:out)
7777
if g.graph[1] isa CuVector
78-
# TODO revisi after https://github.com/JuliaGPU/CUDA.jl/pull/1152
78+
# TODO revisit after https://github.com/JuliaGPU/CUDA.jl/pull/1152
7979
A, n, m = to_dense(g.graph, T, num_nodes=g.num_nodes)
8080
else
8181
A, n, m = to_sparse(g.graph, T, num_nodes=g.num_nodes)

src/GNNGraphs/transform.jl

Lines changed: 123 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ end
5454
"""
5555
add_edges(g::GNNGraph, s::AbstractVector, t::AbstractVector; [edata])
5656
57-
Add to graph `g` the edges with source nodes `s` and target nodes `t`.
58-
57+
Add to graph `g` the edges with source nodes `s` and target nodes `t`.
5958
"""
6059
function add_edges(g::GNNGraph{<:COO_T},
6160
snew::AbstractVector{<:Integer},
@@ -79,6 +78,25 @@ function add_edges(g::GNNGraph{<:COO_T},
7978
g.ndata, edata, g.gdata)
8079
end
8180

81+
82+
"""
83+
add_nodes(g::GNNGraph, n; [ndata])
84+
85+
Add `n` new nodes to graph `g`. In the
86+
new graph, these nodes will have indexes from `g.num_nodes + 1`
87+
to `g.num_nodes + n`.
88+
"""
89+
function add_nodes(g::GNNGraph{<:COO_T}, n::Integer; ndata=(;))
90+
ndata = normalize_graphdata(ndata, default_name=:x, n=n)
91+
ndata = cat_features(g.ndata, ndata)
92+
93+
GNNGraph(g.graph,
94+
g.num_nodes + n, g.num_edges, g.num_graphs,
95+
g.graph_indicator,
96+
ndata, g.edata, g.gdata)
97+
end
98+
99+
82100
function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
83101
nv1, nv2 = g1.num_nodes, g2.num_nodes
84102
if g1.graph isa COO_T
@@ -117,8 +135,6 @@ function SparseArrays.blockdiag(A1::AbstractMatrix, A2::AbstractMatrix)
117135
O2 A2]
118136
end
119137

120-
### Cat public interfaces #############
121-
122138
"""
123139
blockdiag(xs::GNNGraph...)
124140
@@ -133,14 +149,115 @@ function SparseArrays.blockdiag(g1::GNNGraph, gothers::GNNGraph...)
133149
end
134150

135151
"""
136-
batch(xs::Vector{<:GNNGraph})
152+
batch(gs::Vector{<:GNNGraph})
137153
138154
Batch together multiple `GNNGraph`s into a single one
139155
containing the total number of original nodes and edges.
140156
141157
Equivalent to [`SparseArrays.blockdiag`](@ref).
158+
See also [`Flux.unbatch`](@ref).
159+
160+
# Usage
161+
162+
```juliarepl
163+
julia> g1 = rand_graph(4, 6, ndata=ones(8, 4))
164+
GNNGraph:
165+
num_nodes = 4
166+
num_edges = 6
167+
num_graphs = 1
168+
ndata:
169+
x => (8, 4)
170+
edata:
171+
gdata:
172+
173+
174+
julia> g2 = rand_graph(7, 4, ndata=zeros(8, 7))
175+
GNNGraph:
176+
num_nodes = 7
177+
num_edges = 4
178+
num_graphs = 1
179+
ndata:
180+
x => (8, 7)
181+
edata:
182+
gdata:
183+
184+
185+
julia> g12 = Flux.batch([g1, g2])
186+
GNNGraph:
187+
num_nodes = 11
188+
num_edges = 10
189+
num_graphs = 2
190+
ndata:
191+
x => (8, 11)
192+
edata:
193+
gdata:
194+
195+
196+
julia> g12.ndata.x
197+
8×11 Matrix{Float64}:
198+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
199+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
200+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
201+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
202+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
203+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
204+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
205+
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
206+
```
207+
"""
208+
Flux.batch(gs::Vector{<:GNNGraph}) = blockdiag(gs...)
209+
210+
142211
"""
143-
Flux.batch(xs::Vector{<:GNNGraph}) = blockdiag(xs...)
212+
unbatch(g::GNNGraph)
213+
214+
Opposite of the [`Flux.batch`](@ref) operation, returns
215+
an array of the individual graphs batched together in `g`.
216+
217+
See also [`Flux.batch`](@ref) and [`getgraph`](@ref).
218+
219+
# Usage
220+
221+
```juliarepl
222+
julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])
223+
GNNGraph:
224+
num_nodes = 19
225+
num_edges = 16
226+
num_graphs = 3
227+
ndata:
228+
edata:
229+
gdata:
230+
231+
julia> Flux.unbatch(gbatched)
232+
3-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
233+
GNNGraph:
234+
num_nodes = 5
235+
num_edges = 6
236+
num_graphs = 1
237+
ndata:
238+
edata:
239+
gdata:
240+
241+
GNNGraph:
242+
num_nodes = 10
243+
num_edges = 8
244+
num_graphs = 1
245+
ndata:
246+
edata:
247+
gdata:
248+
249+
GNNGraph:
250+
num_nodes = 4
251+
num_edges = 2
252+
num_graphs = 1
253+
ndata:
254+
edata:
255+
gdata:
256+
```
257+
"""
258+
function Flux.unbatch(g::GNNGraph)
259+
[getgraph(g, i) for i in 1:g.num_graphs]
260+
end
144261

145262

146263
"""

test/GNNGraphs/generate.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
@testset "generate" begin
22
@testset "rand_graph" begin
33
n, m = 10, 20
4+
m2 = m ÷ 2
45
x = rand(3, n)
5-
e = rand(4, m)
6+
e = rand(4, m2)
67
g = rand_graph(n, m, ndata=x, edata=e, graph_type=GRAPH_T)
78
@test g.num_nodes == n
8-
@test g.num_edges == 2m
9+
@test g.num_edges == m
910
@test g.ndata.x === x
1011
if GRAPH_T == :coo
1112
s, t = edge_index(g)
12-
@test s[1:m] == t[m+1:end]
13-
@test t[1:m] == s[m+1:end]
14-
@test g.edata.e[:,1:m] == e
15-
@test g.edata.e[:,m+1:end] == e
13+
@test s[1:m2] == t[m2+1:end]
14+
@test t[1:m2] == s[m2+1:end]
15+
@test g.edata.e[:,1:m2] == e
16+
@test g.edata.e[:,m2+1:end] == e
1617
end
17-
g = rand_graph(n, m, directed=true, graph_type=GRAPH_T)
18+
g = rand_graph(n, m, bidirected=false, graph_type=GRAPH_T)
1819
@test g.num_nodes == n
1920
@test g.num_edges == m
2021
end

test/GNNGraphs/transform.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@
4242
@test g123.gdata.u == [g1.gdata.u, g2.gdata.u, g3.gdata.u]
4343
end
4444

45+
@testset "unbatch" begin
46+
g1 = rand_graph(10, 20)
47+
g2 = rand_graph(5, 10)
48+
g12 = Flux.batch([g1, g2])
49+
gs = Flux.unbatch([g1,g2])
50+
@test length(gs) == 2
51+
@test gs[1].num_nodes == 10
52+
@test gs[1].num_edges == 20
53+
@test gs[1].num_graphs == 1
54+
@test gs[2].num_nodes == 5
55+
@test gs[2].num_edges == 10
56+
@test gs[2].num_graphs == 1
57+
end
58+
4559
@testset "getgraph" begin
4660
g1 = GNNGraph(random_regular_graph(10,2), ndata=rand(16,10), graph_type=GRAPH_T)
4761
g2 = GNNGraph(random_regular_graph(4,2), ndata=rand(16,4), graph_type=GRAPH_T)
@@ -80,4 +94,15 @@
8094
@test all(gnew.edata.e2[:,5] .== 0)
8195
end
8296
end
97+
98+
@testset "add_nodes" begin
99+
if GRAPH_T == :coo
100+
g = rand_graph(6, 4, ndata=rand(2, 6), graph_type=GRAPH_T)
101+
gnew = add_nodes(g, 5, ndata=ones(2, 5))
102+
@test gnew.num_nodes == g.num_nodes + 5
103+
@test gnew.num_edges == g.num_edges
104+
@test gnew.num_graphs == g.num_graphs
105+
@test all(gnew.ndata.x[:,7:11] .== 1)
106+
end
107+
end
83108
end

0 commit comments

Comments
 (0)