Skip to content

Commit fbb332f

Browse files
authored
Merge pull request #851 from igraph/feat/nearest-neighbor
2 parents 6630199 + fb6f407 commit fbb332f

File tree

4 files changed

+93
-0
lines changed

4 files changed

+93
-0
lines changed

src/_igraph/convert.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4069,6 +4069,21 @@ int igraphmodule_PyObject_to_pagerank_algo_t(PyObject *o, igraph_pagerank_algo_t
40694069
TRANSLATE_ENUM_WITH(pagerank_algo_tt);
40704070
}
40714071

4072+
/**
4073+
* \ingroup python_interface_conversion
4074+
* \brief Converts a Python object to an igraph \c igraph_metric_t
4075+
*/
4076+
int igraphmodule_PyObject_to_metric_t(PyObject *o, igraph_metric_t *result) {
4077+
static igraphmodule_enum_translation_table_entry_t metric_tt[] = {
4078+
{"euclidean", IGRAPH_METRIC_EUCLIDEAN},
4079+
{"l2", IGRAPH_METRIC_L2}, /* alias to the previous */
4080+
{"manhattan", IGRAPH_METRIC_MANHATTAN},
4081+
{"l1", IGRAPH_METRIC_L1}, /* alias to the previous */
4082+
{0,0}
4083+
};
4084+
TRANSLATE_ENUM_WITH(metric_tt);
4085+
}
4086+
40724087
/**
40734088
* \ingroup python_interface_conversion
40744089
* \brief Converts a Python object to an igraph \c igraph_edge_type_sw_t

src/_igraph/convert.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ int igraphmodule_PyObject_to_laplacian_normalization_t(PyObject *o, igraph_lapla
7878
int igraphmodule_PyObject_to_layout_grid_t(PyObject *o, igraph_layout_grid_t *result);
7979
int igraphmodule_PyObject_to_lpa_variant_t(PyObject *o, igraph_lpa_variant_t *result);
8080
int igraphmodule_PyObject_to_loops_t(PyObject *o, igraph_loops_t *result);
81+
int igraphmodule_PyObject_to_metric_t(PyObject *o, igraph_metric_t *result);
8182
int igraphmodule_PyObject_to_mst_algorithm_t(PyObject *o, igraph_mst_algorithm_t *result);
8283
int igraphmodule_PyObject_to_neimode_t(PyObject *o, igraph_neimode_t *result);
8384
int igraphmodule_PyObject_to_pagerank_algo_t(PyObject *o, igraph_pagerank_algo_t *result);

src/_igraph/graphobject.c

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14034,6 +14034,46 @@ PyObject *igraphmodule_Graph_random_walk(igraphmodule_GraphObject * self,
1403414034
}
1403514035
}
1403614036

14037+
/**********************************************************************
14038+
* Spatial graphs *
14039+
**********************************************************************/
14040+
14041+
PyObject *igraphmodule_Graph_Nearest_Neighbor_Graph(PyTypeObject *type,
14042+
PyObject *args, PyObject *kwds) {
14043+
static char *kwlist[] = {"points", "k", "r", "metric", "directed", NULL};
14044+
PyObject *points_o = Py_None, *metric_o = Py_None, *directed_o = Py_False;
14045+
double r = -1;
14046+
Py_ssize_t k = 1;
14047+
igraph_matrix_t points;
14048+
igraphmodule_GraphObject *self;
14049+
igraph_t graph;
14050+
igraph_metric_t metric = IGRAPH_METRIC_EUCLIDEAN;
14051+
14052+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|ndOO", kwlist,
14053+
&points_o, &k, &r, &metric_o, &directed_o)) {
14054+
return NULL;
14055+
}
14056+
14057+
if (igraphmodule_PyObject_to_metric_t(metric_o, &metric)) {
14058+
return NULL;
14059+
}
14060+
14061+
if (igraphmodule_PyObject_to_matrix_t(points_o, &points, "points")) {
14062+
return NULL;
14063+
}
14064+
14065+
if (igraph_nearest_neighbor_graph(&graph, &points, metric, k, r, PyObject_IsTrue(directed_o))) {
14066+
igraph_matrix_destroy(&points);
14067+
return igraphmodule_handle_igraph_error();
14068+
}
14069+
14070+
igraph_matrix_destroy(&points);
14071+
14072+
CREATE_GRAPH_FROM_TYPE(self, graph, type);
14073+
14074+
return (PyObject *) self;
14075+
}
14076+
1403714077
/**********************************************************************
1403814078
* Special internal methods that you won't need to mess around with *
1403914079
**********************************************************************/
@@ -18894,6 +18934,22 @@ struct PyMethodDef igraphmodule_Graph_methods[] = {
1889418934
" the given length (shorter if the random walk got stuck).\n"
1889518935
},
1889618936

18937+
/**********************/
18938+
/* SPATIAL GRAPHS */
18939+
/**********************/
18940+
{"Nearest_Neighbor_Graph", (PyCFunction)igraphmodule_Graph_Nearest_Neighbor_Graph,
18941+
METH_VARARGS | METH_CLASS | METH_KEYWORDS,
18942+
"Nearest_Neighbor_Graph(points, k=1, r=-1, metric=\"euclidean\", directed=False)\n--\n\n"
18943+
"Constructs a k nearest neighbor graph of a give point set. Each point is\n"
18944+
"connected to at most k spatial neighbors within a radius of 1.\n\n"
18945+
"@param points: coordinates of the points to use, in an arbitrary number of dimensions\n"
18946+
"@param k: at most how many neighbors to connect to. Pass a negative value to ignore\n"
18947+
"@param r: only neighbors within this radius are considered. Pass a negative value to ignore\n"
18948+
"@param metric: the metric to use. C{\"euclidean\"} and C{\"manhattan\"} are supported.\n"
18949+
"@param directed: whethe to create directed edges.\n"
18950+
"@return: the nearest neighbor graph.\n"
18951+
},
18952+
1889718953
/**********************/
1889818954
/* INTERNAL FUNCTIONS */
1889918955
/**********************/

tests/test_generators.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,27 @@ def testDataFrame(self):
890890
edges = pd.DataFrame(np.array([[0, 1], [1, np.nan], [1, 2]]), dtype="Int64")
891891
Graph.DataFrame(edges)
892892

893+
def testNearestNeighborGraph(self):
894+
points = [[0,0], [1,2], [-3, -3]]
895+
896+
g = Graph.Nearest_Neighbor_Graph(points)
897+
# expecting 1 - 2, 3 - 1
898+
self.assertFalse(g.is_directed())
899+
self.assertEqual(g.vcount(), 3)
900+
self.assertEqual(g.ecount(), 2)
901+
902+
g = Graph.Nearest_Neighbor_Graph(points, directed=True)
903+
# expecting 1 <-> 2, 3 -> 1
904+
self.assertTrue(g.is_directed())
905+
self.assertEqual(g.vcount(), 3)
906+
self.assertEqual(g.ecount(), 3)
907+
908+
# expecting a complete graph
909+
g = Graph.Nearest_Neighbor_Graph(points, k=2)
910+
self.assertFalse(g.is_directed())
911+
self.assertEqual(g.vcount(), 3)
912+
self.assertTrue(g.is_complete())
913+
893914

894915
def suite():
895916
generator_suite = unittest.defaultTestLoader.loadTestsFromTestCase(GeneratorTests)

0 commit comments

Comments
 (0)