Skip to content

Commit b2115dd

Browse files
implement knn_graph
1 parent b230ab1 commit b2115dd

File tree

5 files changed

+46
-9
lines changed

5 files changed

+46
-9
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1717
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1818
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1919
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
20+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
2021
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2122
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2223
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -34,6 +35,7 @@ Graphs = "1.4"
3435
KrylovKit = "0.5"
3536
LearnBase = "0.4, 0.5"
3637
MacroTools = "0.5"
38+
NearestNeighbors = "0.4"
3739
NNlib = "0.7"
3840
NNlibCUDA = "0.1"
3941
Reexport = "1"

src/GNNGraphs/GNNGraphs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Graphs
77
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops, is_directed
88
import Flux
99
using Flux: batch
10+
import NearestNeighbors
1011
import NNlib
1112
import LearnBase
1213
import StatsBase
@@ -53,7 +54,8 @@ export add_nodes,
5354
blockdiag
5455

5556
include("generate.jl")
56-
export rand_graph
57+
export rand_graph,
58+
knn_graph
5759

5860
include("operators.jl")
5961
# Base.intersect

src/GNNGraphs/generate.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,19 @@ function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
4444
m2 = bidirected ? m÷2 : m
4545
return GNNGraph(Graphs.erdos_renyi(n, m2; is_directed=!bidirected, seed); kws...)
4646
end
47+
48+
49+
function knn_graph(points::AbstractMatrix, k::Int; self_loops=false, dir=:in, kws...)
50+
kdtree = NearestNeighbors.KDTree(points)
51+
sortres = false
52+
if !self_loops
53+
k += 1
54+
end
55+
idxs, dists = NearestNeighbors.knn(kdtree, points, k, sortres)
56+
# return idxs
57+
g = GNNGraph(idxs; dir, kws...)
58+
if !self_loops
59+
g = remove_self_loops(g)
60+
end
61+
return g
62+
end

src/GNNGraphs/gnngraph.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,26 +198,26 @@ function GNNGraph(g::GNNGraph; ndata=g.ndata, edata=g.edata, gdata=g.gdata, grap
198198
end
199199

200200
function Base.show(io::IO, g::GNNGraph)
201-
println(io, "GNNGraph:
201+
print(io, "GNNGraph:
202202
num_nodes = $(g.num_nodes)
203203
num_edges = $(g.num_edges)")
204-
g.num_graphs > 1 && println("num_graphs = $(g.num_graphs)")
204+
g.num_graphs > 1 && print("\nnum_graphs = $(g.num_graphs)")
205205
if !isempty(g.ndata)
206-
println(io, " ndata:")
206+
print(io, "\n ndata:")
207207
for k in keys(g.ndata)
208-
println(io, " $k => $(size(g.ndata[k]))")
208+
print(io, "\n $k => $(size(g.ndata[k]))")
209209
end
210210
end
211211
if !isempty(g.edata)
212-
println(io, " edata:")
212+
print(io, "\n edata:")
213213
for k in keys(g.edata)
214-
println(io, " $k => $(size(g.edata[k]))")
214+
print(io, "\n $k => $(size(g.edata[k]))")
215215
end
216216
end
217217
if !isempty(g.gdata)
218-
println(io, " gdata:")
218+
print(io, "\n gdata:")
219219
for k in keys(g.gdata)
220-
println(io, " $k => $(size(g.gdata[k]))")
220+
print(io, "\n $k => $(size(g.gdata[k]))")
221221
end
222222
end
223223
end

test/GNNGraphs/generate.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,21 @@
2424
g2 = rand_graph(n, m, bidirected=false, seed=17, graph_type=GRAPH_T)
2525
@test edge_index(g2) == edge_index(g)
2626
end
27+
28+
@testset "knn_graph" begin
29+
n = 10
30+
k = 3
31+
x = rand(3, n)
32+
g = knn_graph(x, k; graph_type=GRAPH_T)
33+
@test g.num_nodes == 10
34+
@test g.num_edges == n*k
35+
@test degree(g, dir=:in) == fill(k, n)
36+
@test has_self_loops(g) == false
37+
38+
g = knn_graph(x, k; dir=:out, self_loops=true, graph_type=GRAPH_T)
39+
@test g.num_nodes == 10
40+
@test g.num_edges == n*k
41+
@test degree(g, dir=:out) == fill(k, n)
42+
@test has_self_loops(g) == true
43+
end
2744
end

0 commit comments

Comments
 (0)