Skip to content

Commit a4c139a

Browse files
committed
radius_graph
1 parent 78792f0 commit a4c139a

File tree

3 files changed

+95
-1
lines changed

3 files changed

+95
-1
lines changed

src/GNNGraphs/GNNGraphs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ export add_nodes,
6060

6161
include("generate.jl")
6262
export rand_graph,
63-
knn_graph
63+
knn_graph,
64+
radius_graph
6465

6566
include("sampling.jl")
6667
export sample_neighbors

src/GNNGraphs/generate.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,78 @@ function knn_graph(points::AbstractMatrix, k::Int;
127127
end
128128
return g
129129
end
130+
131+
"""
132+
radius_graph(points::AbstractMatrix,
133+
r::AbstractFloat;
134+
graph_indicator = nothing,
135+
self_loops = false,
136+
dir = :in,
137+
kws...)
138+
139+
Create a graph where each node is linked
140+
to its neighbors within a given distance `r`.
141+
142+
# Arguments
143+
144+
- `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
145+
- `r`: The radius.
146+
- `graph_indicator`: Either nothing or a vector containing the graph assigment of each node,
147+
in which case the returned graph will be a batch of graphs.
148+
- `self_loops`: If `true`, consider the node itself among its `k` nearest neighbors, in which
149+
case the graph will contain self-loops.
150+
- `dir`: The direction of the edges. If `dir=:in` edges go from the `k`
151+
neighbors to the central node. If `dir=:out` we have the opposite
152+
direction.
153+
- `kws`: Further keyword arguments will be passed to the [`GNNGraph ](@ref) constructor.
154+
155+
# Examples
156+
157+
```juliarepl
158+
julia> n, r = 10, 0.75;
159+
160+
julia> x = rand(3, n);
161+
162+
julia> g = radius_graph(x, r)
163+
GNNGraph:
164+
num_nodes = 10
165+
num_edges = 46
166+
167+
julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
168+
169+
julia> g = radius_graph(x, r; graph_indicator)
170+
GNNGraph:
171+
num_nodes = 10
172+
num_edges = 20
173+
num_graphs = 2
174+
175+
```
176+
"""
177+
function radius_graph(points::AbstractMatrix, r::AbstractFloat;
178+
graph_indicator = nothing,
179+
self_loops = false,
180+
dir = :in,
181+
kws...)
182+
183+
if graph_indicator !== nothing
184+
d, n = size(points)
185+
@assert graph_indicator isa AbstractVector{<:Integer}
186+
@assert length(graph_indicator) == n
187+
188+
# Make sure that the distance between points in different graphs
189+
# is always larger than r.
190+
dummy_feature = 2r .* reshape(graph_indicator, 1, n)
191+
points = vcat(points, dummy_feature)
192+
end
193+
194+
balltree = NearestNeighbors.BallTree(points)
195+
196+
sortres = false
197+
idxs = NearestNeighbors.inrange(balltree, points, r, sortres)
198+
199+
g = GNNGraph(idxs; dir, graph_indicator, kws...)
200+
if !self_loops
201+
g = remove_self_loops(g)
202+
end
203+
return g
204+
end

test/GNNGraphs/generate.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,22 @@
5050
@test all(6 .<= s[ne+1:end] .<= 10)
5151
@test all(6 .<= t[ne+1:end] .<= 10)
5252
end
53+
54+
@testset "radius_graph" begin
55+
n, r = 10, 0.5
56+
x = rand(3, n)
57+
g = radius_graph(x, r; graph_type=GRAPH_T)
58+
@test g.num_nodes == 10
59+
@test has_self_loops(g) == false
60+
61+
g = radius_graph(x, r; dir=:out, self_loops=true, graph_type=GRAPH_T)
62+
@test g.num_nodes == 10
63+
@test has_self_loops(g) == true
64+
65+
graph_indicator = [1,1,1,1,1,2,2,2,2,2]
66+
g = radius_graph(x, r; graph_indicator, graph_type=GRAPH_T)
67+
@test g.num_graphs == 2
68+
s, t = edge_index(g)
69+
@test (s.>5) == (t.>5)
70+
end
5371
end

0 commit comments

Comments
 (0)