Skip to content

Commit 1cbb5db

Browse files
support batched graphs
1 parent b2115dd commit 1cbb5db

File tree

8 files changed

+111
-17
lines changed

8 files changed

+111
-17
lines changed

src/GNNGraphs/convert.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function to_dense(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
6969
num_nodes = size(A, 1)
7070
@assert num_nodes == size(A, 2)
7171
# @assert all(x -> (x == 1) || (x == 0), A)
72-
num_edges = round(Int, sum(A))
72+
num_edges = numnonzeros(A)
7373
if dir == :in
7474
A = A'
7575
end
@@ -85,7 +85,7 @@ function to_dense(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
8585
num_edges = sum(length.(adj_list))
8686
@assert num_nodes > 0
8787
T = T === nothing ? eltype(adj_list[1]) : T
88-
A = similar(adj_list[1], T, (num_nodes, num_nodes))
88+
A = fill!(similar(adj_list[1], T, (num_nodes, num_nodes)), 0)
8989
if dir == :out
9090
for (i, neigs) in enumerate(adj_list)
9191
A[i, neigs] .= 1
@@ -129,13 +129,13 @@ function to_sparse(A::ADJMAT_T, T=nothing; dir=:out, num_nodes=nothing)
129129
end
130130
if !(A isa AbstractSparseMatrix)
131131
A = sparse(A)
132-
end
132+
end
133133
return A, num_nodes, num_edges
134134
end
135135

136136
function to_sparse(adj_list::ADJLIST_T, T=nothing; dir=:out, num_nodes=nothing)
137137
coo, num_nodes, num_edges = to_coo(adj_list; dir, num_nodes)
138-
return to_sparse(coo; dir, num_nodes)
138+
return to_sparse(coo; num_nodes)
139139
end
140140

141141
function to_sparse(coo::COO_T, T=nothing; dir=:out, num_nodes=nothing)

src/GNNGraphs/generate.jl

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Use a `seed > 0` for reproducibility.
1212
1313
Additional keyword arguments will be passed to the [`GNNGraph`](@ref) constructor.
1414
15-
# Usage
15+
# Examples
1616
1717
```juliarepl
1818
julia> g = rand_graph(5, 4, bidirected=false)
@@ -46,15 +46,82 @@ function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
4646
end
4747

4848

49-
function knn_graph(points::AbstractMatrix, k::Int; self_loops=false, dir=:in, kws...)
49+
"""
50+
knn_graph(points::AbstractMatrix,
51+
k::Int;
52+
graph_indicator = nothing,
53+
self_loops = false,
54+
dir = :in,
55+
kws...)
56+
57+
Create a `k`-nearest neighbor graph where each node is linked
58+
to its `k` closest `points`.
59+
60+
# Arguments
61+
62+
- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
63+
- `k`: The number of neighbors considered in the kNN algorithm.
64+
- `graph_indicator`: Either nothing or a vector containing the graph assigment of each node,
65+
in which case the returned graph will be a batch of graphs.
66+
- `self_loops`: If `true`, consider the node itself among its `k` nearest neighbors, in which
67+
case the graph will contain self-loops.
68+
- `dir`: The direction of the edges. If `dir=:in` edges go from the `k`
69+
neighbors to the central node. If `dir=:out` we have the opposite
70+
direction.
71+
- `kws`: Further keyword arguments will be passed to the [`GNNGraph ](@ref) constructor.
72+
73+
# Examples
74+
75+
```juliarepl
76+
julia> n, k = 10, 3;
77+
78+
julia> x = rand(3, n);
79+
80+
julia> g = knn_graph(x, k)
81+
GNNGraph:
82+
num_nodes = 10
83+
num_edges = 30
84+
85+
julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
86+
87+
julia> g = knn_graph(x, k; graph_indicator)
88+
GNNGraph:
89+
num_nodes = 10
90+
num_edges = 30
91+
num_graphs = 2
92+
93+
```
94+
"""
95+
function knn_graph(points::AbstractMatrix, k::Int;
96+
graph_indicator = nothing,
97+
self_loops = false,
98+
dir = :in,
99+
kws...)
100+
101+
if graph_indicator !== nothing
102+
d, n = size(points)
103+
@assert graph_indicator isa AbstractVector{<:Integer}
104+
@assert length(graph_indicator) == n
105+
# All graphs in the batch must have at least k nodes.
106+
cm = StatsBase.countmap(graph_indicator)
107+
@assert all(values(cm) .>= k)
108+
109+
# Make sure that the distance between points in different graphs
110+
# is always larger than any distance within the same graph.
111+
points = points .- minimum(points)
112+
points = points ./ maximum(points)
113+
dummy_feature = 2d .* reshape(graph_indicator, 1, n)
114+
points = vcat(points, dummy_feature)
115+
end
116+
50117
kdtree = NearestNeighbors.KDTree(points)
51-
sortres = false
52118
if !self_loops
53119
k += 1
54120
end
121+
sortres = false
55122
idxs, dists = NearestNeighbors.knn(kdtree, points, k, sortres)
56-
# return idxs
57-
g = GNNGraph(idxs; dir, kws...)
123+
124+
g = GNNGraph(idxs; dir, graph_indicator, kws...)
58125
if !self_loops
59126
g = remove_self_loops(g)
60127
end

src/GNNGraphs/gnngraph.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ functionality from that library.
6565
- `edata`: Edge features. A named tuple of arrays whose last dimension has size `num_edges`.
6666
- `gdata`: Graph features. A named tuple of arrays whose last dimension has size `num_graphs`.
6767
68-
# Usage.
68+
# Examples
6969
7070
```julia
7171
using Flux, GraphNeuralNetworks
@@ -201,7 +201,7 @@ function Base.show(io::IO, g::GNNGraph)
201201
print(io, "GNNGraph:
202202
num_nodes = $(g.num_nodes)
203203
num_edges = $(g.num_edges)")
204-
g.num_graphs > 1 && print("\nnum_graphs = $(g.num_graphs)")
204+
g.num_graphs > 1 && print("\n num_graphs = $(g.num_graphs)")
205205
if !isempty(g.ndata)
206206
print(io, "\n ndata:")
207207
for k in keys(g.ndata)

src/GNNGraphs/transform.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,22 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
5555
g.ndata, g.edata, g.gdata)
5656
end
5757

58+
59+
function remove_self_loops(g::GNNGraph{<:ADJMAT_T})
60+
@assert g.edata === (;)
61+
A = g.graph
62+
A[diagind(A)] .= 0
63+
if A isa AbstractSparseMatrix
64+
dropzeros!(A)
65+
end
66+
num_edges = numnonzeros(A)
67+
GNNGraph(A,
68+
g.num_nodes, num_edges, g.num_graphs,
69+
g.graph_indicator,
70+
g.ndata, g.edata, g.gdata)
71+
end
72+
73+
5874
"""
5975
remove_multi_edges(g::GNNGraph)
6076
@@ -183,7 +199,7 @@ containing the total number of original nodes and edges.
183199
Equivalent to [`SparseArrays.blockdiag`](@ref).
184200
See also [`Flux.unbatch`](@ref).
185201
186-
# Usage
202+
# Examples
187203
188204
```juliarepl
189205
julia> g1 = rand_graph(4, 6, ndata=ones(8, 4))
@@ -231,7 +247,7 @@ an array of the individual graphs batched together in `g`.
231247
232248
See also [`Flux.batch`](@ref) and [`getgraph`](@ref).
233249
234-
# Usage
250+
# Examples
235251
236252
```juliarepl
237253
julia> gbatched = Flux.batch([rand_graph(5, 6), rand_graph(10, 8), rand_graph(4,2)])

src/GNNGraphs/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ ones_like(x::AbstractArray, T=eltype(x), sz=size(x)) = fill!(similar(x, T, sz),
7171
ones_like(x::SparseMatrixCSC, T=eltype(x), sz=size(x)) = ones(T, sz)
7272
ones_like(x::CUMAT_T, T=eltype(x), sz=size(x)) = CUDA.ones(T, sz)
7373

74+
numnonzeros(a::AbstractSparseMatrix) = nnz(a)
75+
numnonzeros(a::AbstractMatrix) = count(!=(0), a)
7476

7577
# each edge is represented by a number in
7678
# 1:N^2

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,7 @@ as the input size.
713713
- `init`: Weights' initializer.
714714
- `residual`: Add a residual connection.
715715
716-
# Usage
716+
# Examples
717717
718718
```julia
719719
x = rand(Float32, 2, g.num_nodes)

src/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ providing as input `f` a closure.
3131
with the same batch size.
3232
- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`.
3333
34-
# Usage Examples
34+
# Examples
3535
3636
```julia
3737
using GraphNeuralNetworks, Flux

test/GNNGraphs/generate.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@
2626
end
2727

2828
@testset "knn_graph" begin
29-
n = 10
30-
k = 3
29+
n, k = 10, 3
3130
x = rand(3, n)
3231
g = knn_graph(x, k; graph_type=GRAPH_T)
3332
@test g.num_nodes == 10
@@ -40,5 +39,15 @@
4039
@test g.num_edges == n*k
4140
@test degree(g, dir=:out) == fill(k, n)
4241
@test has_self_loops(g) == true
42+
43+
graph_indicator = [1,1,1,1,1,2,2,2,2,2]
44+
g = knn_graph(x, k; graph_indicator, graph_type=GRAPH_T)
45+
@test g.num_graphs == 2
46+
s, t = edge_index(g)
47+
ne = n*k÷2
48+
@test all(1 .<= s[1:ne] .<= 5)
49+
@test all(1 .<= t[1:ne] .<= 5)
50+
@test all(6 .<= s[ne+1:end] .<= 10)
51+
@test all(6 .<= t[ne+1:end] .<= 10)
4352
end
4453
end

0 commit comments

Comments
 (0)