@@ -127,3 +127,78 @@ function knn_graph(points::AbstractMatrix, k::Int;
127
127
end
128
128
return g
129
129
end
130
+
131
+ """
132
+ radius_graph(points::AbstractMatrix,
133
+ r::AbstractFloat;
134
+ graph_indicator = nothing,
135
+ self_loops = false,
136
+ dir = :in,
137
+ kws...)
138
+
139
+ Create a graph where each node is linked
140
+ to its neighbors within a given distance `r`.
141
+
142
+ # Arguments
143
+
144
+ - `points`: A num_features × num_nodes matrix storing the Euclidean positions of the nodes.
145
+ - `r`: The radius.
146
+ - `graph_indicator`: Either nothing or a vector containing the graph assigment of each node,
147
+ in which case the returned graph will be a batch of graphs.
148
+ - `self_loops`: If `true`, consider the node itself among its `k` nearest neighbors, in which
149
+ case the graph will contain self-loops.
150
+ - `dir`: The direction of the edges. If `dir=:in` edges go from the `k`
151
+ neighbors to the central node. If `dir=:out` we have the opposite
152
+ direction.
153
+ - `kws`: Further keyword arguments will be passed to the [`GNNGraph ](@ref) constructor.
154
+
155
+ # Examples
156
+
157
+ ```juliarepl
158
+ julia> n, r = 10, 0.75;
159
+
160
+ julia> x = rand(3, n);
161
+
162
+ julia> g = radius_graph(x, r)
163
+ GNNGraph:
164
+ num_nodes = 10
165
+ num_edges = 46
166
+
167
+ julia> graph_indicator = [1,1,1,1,1,2,2,2,2,2];
168
+
169
+ julia> g = radius_graph(x, r; graph_indicator)
170
+ GNNGraph:
171
+ num_nodes = 10
172
+ num_edges = 20
173
+ num_graphs = 2
174
+
175
+ ```
176
+ """
177
+ function radius_graph (points:: AbstractMatrix , r:: AbstractFloat ;
178
+ graph_indicator = nothing ,
179
+ self_loops = false ,
180
+ dir = :in ,
181
+ kws... )
182
+
183
+ if graph_indicator != = nothing
184
+ d, n = size (points)
185
+ @assert graph_indicator isa AbstractVector{<: Integer }
186
+ @assert length (graph_indicator) == n
187
+
188
+ # Make sure that the distance between points in different graphs
189
+ # is always larger than r.
190
+ dummy_feature = 2 r .* reshape (graph_indicator, 1 , n)
191
+ points = vcat (points, dummy_feature)
192
+ end
193
+
194
+ balltree = NearestNeighbors. BallTree (points)
195
+
196
+ sortres = false
197
+ idxs = NearestNeighbors. inrange (balltree, points, r, sortres)
198
+
199
+ g = GNNGraph (idxs; dir, graph_indicator, kws... )
200
+ if ! self_loops
201
+ g = remove_self_loops (g)
202
+ end
203
+ return g
204
+ end
0 commit comments