@@ -41,8 +41,8 @@ function rand_graph(n::Integer, m::Integer; bidirected=true, seed=-1, kws...)
41
41
if bidirected
42
42
@assert iseven (m) " Need even number of edges for bidirected graphs, given m=$m ."
43
43
end
44
- m2 = bidirected ? m ÷ 2 : m
45
- return GNNGraph (Graphs. erdos_renyi (n, m2; is_directed= ! bidirected, seed); kws... )
44
+ m2 = bidirected ? m÷ 2 : m
45
+ return GNNGraph (Graphs. erdos_renyi (n, m2; is_directed= ! bidirected, seed); kws... )
46
46
end
47
47
48
48
@@ -92,11 +92,11 @@ GNNGraph:
92
92
93
93
```
94
94
"""
95
- function knn_graph (points:: AbstractMatrix , k:: Int ;
96
- graph_indicator= nothing ,
97
- self_loops= false ,
98
- dir= :in ,
99
- kws... )
95
+ function knn_graph (points:: AbstractMatrix , k:: Int ;
96
+ graph_indicator = nothing ,
97
+ self_loops = false ,
98
+ dir = :in ,
99
+ kws... )
100
100
101
101
if graph_indicator != = nothing
102
102
d, n = size (points)
@@ -105,22 +105,22 @@ function knn_graph(points::AbstractMatrix, k::Int;
105
105
# All graphs in the batch must have at least k nodes.
106
106
cm = StatsBase. countmap (graph_indicator)
107
107
@assert all (values (cm) .>= k)
108
-
108
+
109
109
# Make sure that the distance between points in different graphs
110
110
# is always larger than any distance within the same graph.
111
111
points = points .- minimum (points)
112
112
points = points ./ maximum (points)
113
113
dummy_feature = 2 d .* reshape (graph_indicator, 1 , n)
114
114
points = vcat (points, dummy_feature)
115
115
end
116
-
116
+
117
117
kdtree = NearestNeighbors. KDTree (points)
118
118
if ! self_loops
119
119
k += 1
120
120
end
121
121
sortres = false
122
122
idxs, dists = NearestNeighbors. knn (kdtree, points, k, sortres)
123
-
123
+
124
124
g = GNNGraph (idxs; dir, graph_indicator, kws... )
125
125
if ! self_loops
126
126
g = remove_self_loops (g)
@@ -174,17 +174,17 @@ GNNGraph:
174
174
175
175
```
176
176
"""
177
- function radius_graph (points:: AbstractMatrix , r:: AbstractFloat ;
178
- graph_indicator= nothing ,
179
- self_loops= false ,
180
- dir= :in ,
181
- kws... )
177
+ function radius_graph (points:: AbstractMatrix , r:: AbstractFloat ;
178
+ graph_indicator = nothing ,
179
+ self_loops = false ,
180
+ dir = :in ,
181
+ kws... )
182
182
183
183
if graph_indicator != = nothing
184
184
d, n = size (points)
185
185
@assert graph_indicator isa AbstractVector{<: Integer }
186
186
@assert length (graph_indicator) == n
187
-
187
+
188
188
# Make sure that the distance between points in different graphs
189
189
# is always larger than r.
190
190
dummy_feature = 2 r .* reshape (graph_indicator, 1 , n)
@@ -195,7 +195,7 @@ function radius_graph(points::AbstractMatrix, r::AbstractFloat;
195
195
196
196
sortres = false
197
197
idxs = NearestNeighbors. inrange (balltree, points, r, sortres)
198
-
198
+
199
199
g = GNNGraph (idxs; dir, graph_indicator, kws... )
200
200
if ! self_loops
201
201
g = remove_self_loops (g)
0 commit comments