Skip to content

Commit 21194f7

Browse files
Prerak SinghPrerak Singh
authored andcommitted
added cpp implementation for prims
1 parent f7a6296 commit 21194f7

File tree

7 files changed

+449
-29
lines changed

7 files changed

+449
-29
lines changed
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
#include <Python.h>
2+
#include <unordered_map>
3+
#include <queue>
4+
#include <string>
5+
#include <unordered_set>
6+
#include <variant>
7+
#include "GraphEdge.hpp"
8+
#include "AdjacencyList.hpp"
9+
#include "AdjacencyMatrix.hpp"
10+
11+
static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
12+
PyObject* graph_obj;
13+
const char* source_name;
14+
PyObject* operation;
15+
PyObject* varargs = nullptr;
16+
PyObject* kwargs_dict = nullptr;
17+
18+
static const char* kwlist[] = {"graph", "source_node", "operation", "args", "kwargs", nullptr};
19+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!sO|OO", const_cast<char**>(kwlist),
20+
&AdjacencyListGraphType, &graph_obj,
21+
&source_name, &operation,
22+
&varargs, &kwargs_dict)) {
23+
return nullptr;
24+
}
25+
26+
AdjacencyListGraph* cpp_graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);
27+
28+
auto it = cpp_graph->node_map.find(source_name);
29+
AdjacencyListGraphNode* start_node = it->second;
30+
31+
std::unordered_set<std::string> visited;
32+
std::queue<AdjacencyListGraphNode*> q;
33+
34+
q.push(start_node);
35+
visited.insert(start_node->name);
36+
37+
while (!q.empty()) {
38+
AdjacencyListGraphNode* node = q.front();
39+
q.pop();
40+
41+
for (const auto& [adj_name, adj_obj] : node->adjacent) {
42+
if (visited.count(adj_name)) continue;
43+
if (!PyObject_IsInstance(adj_obj, (PyObject*)&AdjacencyListGraphNodeType)) continue;
44+
45+
AdjacencyListGraphNode* adj_node = reinterpret_cast<AdjacencyListGraphNode*>(adj_obj);
46+
47+
PyObject* base_args = PyTuple_Pack(2,
48+
reinterpret_cast<PyObject*>(node),
49+
reinterpret_cast<PyObject*>(adj_node));
50+
if (!base_args)
51+
return nullptr;
52+
53+
PyObject* final_args;
54+
if (varargs && PyTuple_Check(varargs)) {
55+
final_args = PySequence_Concat(base_args, varargs);
56+
Py_DECREF(base_args);
57+
if (!final_args)
58+
return nullptr;
59+
} else {
60+
final_args = base_args;
61+
}
62+
63+
PyObject* result = PyObject_Call(operation, final_args, kwargs_dict);
64+
Py_DECREF(final_args);
65+
66+
if (!result)
67+
return nullptr;
68+
69+
Py_DECREF(result);
70+
71+
visited.insert(adj_name);
72+
q.push(adj_node);
73+
}
74+
}
75+
if (PyErr_Occurred()) {
76+
return nullptr;
77+
}
78+
79+
Py_RETURN_NONE;
80+
}
81+
82+
static PyObject* breadth_first_search_adjacency_matrix(PyObject* self, PyObject* args, PyObject* kwargs) {
83+
PyObject* graph_obj;
84+
const char* source_name;
85+
PyObject* operation;
86+
PyObject* varargs = nullptr;
87+
PyObject* kwargs_dict = nullptr;
88+
89+
static const char* kwlist[] = {"graph", "source_node", "operation", "args", "kwargs", nullptr};
90+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!sO|OO", const_cast<char**>(kwlist),
91+
&AdjacencyMatrixGraphType, &graph_obj,
92+
&source_name, &operation,
93+
&varargs, &kwargs_dict)) {
94+
return nullptr;
95+
}
96+
97+
AdjacencyMatrixGraph* cpp_graph = reinterpret_cast<AdjacencyMatrixGraph*>(graph_obj);
98+
99+
auto it = cpp_graph->node_map.find(source_name);
100+
if (it == cpp_graph->node_map.end()) {
101+
PyErr_SetString(PyExc_KeyError, "Source node not found in graph");
102+
return nullptr;
103+
}
104+
AdjacencyMatrixGraphNode* start_node = it->second;
105+
106+
std::unordered_set<std::string> visited;
107+
std::queue<AdjacencyMatrixGraphNode*> q;
108+
109+
q.push(start_node);
110+
visited.insert(source_name);
111+
112+
while (!q.empty()) {
113+
AdjacencyMatrixGraphNode* node = q.front();
114+
q.pop();
115+
116+
std::string node_name = reinterpret_cast<GraphNode*>(node)->name;
117+
auto& neighbors = cpp_graph->matrix[node_name];
118+
119+
for (const auto& [adj_name, connected] : neighbors) {
120+
if (!connected || visited.count(adj_name)) continue;
121+
122+
auto adj_it = cpp_graph->node_map.find(adj_name);
123+
if (adj_it == cpp_graph->node_map.end()) continue;
124+
125+
AdjacencyMatrixGraphNode* adj_node = adj_it->second;
126+
127+
PyObject* base_args = PyTuple_Pack(2,
128+
reinterpret_cast<PyObject*>(node),
129+
reinterpret_cast<PyObject*>(adj_node));
130+
if (!base_args) return nullptr;
131+
132+
PyObject* final_args;
133+
if (varargs && PyTuple_Check(varargs)) {
134+
final_args = PySequence_Concat(base_args, varargs);
135+
Py_DECREF(base_args);
136+
if (!final_args) return nullptr;
137+
} else {
138+
final_args = base_args;
139+
}
140+
141+
PyObject* result = PyObject_Call(operation, final_args, kwargs_dict);
142+
Py_DECREF(final_args);
143+
if (!result) return nullptr;
144+
Py_DECREF(result);
145+
146+
visited.insert(adj_name);
147+
q.push(adj_node);
148+
}
149+
}
150+
151+
if (PyErr_Occurred()) {
152+
return nullptr;
153+
}
154+
155+
Py_RETURN_NONE;
156+
}
157+
158+
static PyObject* minimum_spanning_tree_prim_adjacency_list(PyObject* self, PyObject* args, PyObject* kwargs) {
159+
160+
PyObject* graph_obj;
161+
static const char* kwlist[] = {"graph", nullptr};
162+
163+
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", const_cast<char**>(kwlist),
164+
&AdjacencyListGraphType, &graph_obj)) {
165+
return nullptr;
166+
}
167+
168+
AdjacencyListGraph* graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);
169+
170+
struct EdgeTuple {
171+
std::string source;
172+
std::string target;
173+
std::variant<std::monostate, int64_t, double, std::string> value;
174+
DataType value_type;
175+
176+
bool operator>(const EdgeTuple& other) const {
177+
if (value_type != other.value_type)
178+
return value_type > other.value_type;
179+
if (std::holds_alternative<int64_t>(value))
180+
return std::get<int64_t>(value) > std::get<int64_t>(other.value);
181+
if (std::holds_alternative<double>(value))
182+
return std::get<double>(value) > std::get<double>(other.value);
183+
if (std::holds_alternative<std::string>(value))
184+
return std::get<std::string>(value) > std::get<std::string>(other.value);
185+
return false;
186+
}
187+
};
188+
189+
std::priority_queue<EdgeTuple, std::vector<EdgeTuple>, std::greater<>> pq;
190+
std::unordered_set<std::string> visited;
191+
192+
PyObject* mst_graph = PyObject_CallObject(reinterpret_cast<PyObject*>(&AdjacencyListGraphType), nullptr);
193+
AdjacencyListGraph* mst = reinterpret_cast<AdjacencyListGraph*>(mst_graph);
194+
195+
std::string start = graph->node_map.begin()->first;
196+
visited.insert(start);
197+
198+
AdjacencyListGraphNode* start_node = graph->node_map[start];
199+
200+
Py_INCREF(start_node);
201+
mst->nodes.push_back(start_node);
202+
mst->node_map[start] = start_node;
203+
204+
for (const auto& [adj_name, _] : start_node->adjacent) {
205+
std::string key = make_edge_key(start, adj_name);
206+
GraphEdge* edge = graph->edges[key];
207+
pq.push({start, adj_name, edge->value, edge->value_type});
208+
}
209+
210+
while (!pq.empty()) {
211+
EdgeTuple edge = pq.top();
212+
pq.pop();
213+
214+
if (visited.count(edge.target)) continue;
215+
visited.insert(edge.target);
216+
217+
for (const std::string& name : {edge.source, edge.target}) {
218+
if (!mst->node_map.count(name)) {
219+
AdjacencyListGraphNode* node = graph->node_map[name];
220+
Py_INCREF(node);
221+
mst->nodes.push_back(node);
222+
mst->node_map[name] = node;
223+
}
224+
}
225+
226+
AdjacencyListGraphNode* u = mst->node_map[edge.source];
227+
AdjacencyListGraphNode* v = mst->node_map[edge.target];
228+
229+
Py_INCREF(v);
230+
Py_INCREF(u);
231+
u->adjacent[edge.target] = reinterpret_cast<PyObject*>(v);
232+
v->adjacent[edge.source] = reinterpret_cast<PyObject*>(u);
233+
234+
std::string key_uv = make_edge_key(edge.source, edge.target);
235+
GraphEdge* new_edge = PyObject_New(GraphEdge, &GraphEdgeType);
236+
Py_INCREF(u);
237+
Py_INCREF(v);
238+
new_edge->source = reinterpret_cast<PyObject*>(u);
239+
new_edge->target = reinterpret_cast<PyObject*>(v);
240+
new (&new_edge->value) std::variant<std::monostate, int64_t, double, std::string>(edge.value);
241+
new_edge->value_type = edge.value_type;
242+
mst->edges[key_uv] = new_edge;
243+
244+
std::string key_vu = make_edge_key(edge.target, edge.source);
245+
GraphEdge* new_edge_rev = PyObject_New(GraphEdge, &GraphEdgeType);
246+
new_edge_rev->source = reinterpret_cast<PyObject*>(v);
247+
new_edge_rev->target = reinterpret_cast<PyObject*>(u);
248+
new (&new_edge_rev->value) std::variant<std::monostate, int64_t, double, std::string>(edge.value);
249+
new_edge_rev->value_type = edge.value_type;
250+
mst->edges[key_vu] = new_edge_rev;
251+
252+
AdjacencyListGraphNode* next_node = graph->node_map[edge.target];
253+
254+
for (const auto& [adj_name, _] : next_node->adjacent) {
255+
if (visited.count(adj_name)) continue;
256+
std::string key = make_edge_key(edge.target, adj_name);
257+
GraphEdge* adj_edge = graph->edges[key];
258+
pq.push({edge.target, adj_name, adj_edge->value, adj_edge->value_type});
259+
}
260+
}
261+
262+
Py_INCREF(mst);
263+
return reinterpret_cast<PyObject*>(mst);
264+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include <Python.h>
2+
#include "Algorithms.hpp"
3+
#include "AdjacencyList.hpp"
4+
#include "AdjacencyMatrix.hpp"
5+
6+
static PyMethodDef AlgorithmsMethods[] = {
7+
{"bfs_adjacency_list", (PyCFunction)breadth_first_search_adjacency_list, METH_VARARGS | METH_KEYWORDS, "Run BFS on adjacency list with callback"},
8+
{"bfs_adjacency_matrix", (PyCFunction)breadth_first_search_adjacency_matrix, METH_VARARGS | METH_KEYWORDS, "Run BFS on adjacency matrix with callback"},
9+
{"minimum_spanning_tree_prim_adjacency_list", (PyCFunction)minimum_spanning_tree_prim_adjacency_list, METH_VARARGS | METH_KEYWORDS, "Run Prim's algorithm on adjacency list"},
10+
{NULL, NULL, 0, NULL}
11+
};
12+
13+
static struct PyModuleDef algorithms_module = {
14+
PyModuleDef_HEAD_INIT,
15+
"_algorithms", NULL, -1, AlgorithmsMethods
16+
};
17+
18+
PyMODINIT_FUNC PyInit__algorithms(void) {
19+
return PyModule_Create(&algorithms_module);
20+
}

pydatastructs/graphs/algorithms.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -330,16 +330,20 @@ def minimum_spanning_tree(graph, algorithm, **kwargs):
330330
should be used only for such graphs. Using with other
331331
types of graphs may lead to unwanted results.
332332
"""
333-
raise_if_backend_is_not_python(
334-
minimum_spanning_tree, kwargs.get('backend', Backend.PYTHON))
335-
import pydatastructs.graphs.algorithms as algorithms
336-
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
337-
if not hasattr(algorithms, func):
338-
raise NotImplementedError(
339-
"Currently %s algoithm for %s implementation of graphs "
340-
"isn't implemented for finding minimum spanning trees."
341-
%(algorithm, graph._impl))
342-
return getattr(algorithms, func)(graph)
333+
backend = kwargs.get('backend', Backend.PYTHON)
334+
if backend == Backend.PYTHON:
335+
import pydatastructs.graphs.algorithms as algorithms
336+
func = "_minimum_spanning_tree_" + algorithm + "_" + graph._impl
337+
if not hasattr(algorithms, func):
338+
raise NotImplementedError(
339+
"Currently %s algoithm for %s implementation of graphs "
340+
"isn't implemented for finding minimum spanning trees."
341+
%(algorithm, graph._impl))
342+
return getattr(algorithms, func)(graph)
343+
else:
344+
from pydatastructs.graphs._backend.cpp._algorithms import minimum_spanning_tree_prim_adjacency_list
345+
if graph._impl == "adjacency_list" and algorithm == 'prim':
346+
return minimum_spanning_tree_prim_adjacency_list(graph)
343347

344348
def _minimum_spanning_tree_parallel_kruskal_adjacency_list(graph, num_threads):
345349
mst = _generate_mst_object(graph)

pydatastructs/graphs/tests/test_adjacency_list.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,41 @@ def test_adjacency_list():
4242

4343
assert raises(ValueError, lambda: g.add_edge('u', 'v'))
4444
assert raises(ValueError, lambda: g.add_edge('v', 'x'))
45+
46+
v_4 = AdjacencyListGraphNode('v_4', 4, backend = Backend.CPP)
47+
v_5 = AdjacencyListGraphNode('v_5', 5, backend = Backend.CPP)
48+
g2 = Graph(v_4,v_5,implementation = 'adjacency_list', backend = Backend.CPP)
49+
v_6 = AdjacencyListGraphNode('v_6', 6, backend = Backend.CPP)
50+
assert raises(ValueError, lambda: g2.add_vertex(v_5))
51+
g2.add_vertex(v_6)
52+
g2.add_edge('v_4', 'v_5')
53+
g2.add_edge('v_5', 'v_6')
54+
g2.add_edge('v_4', 'v_6')
55+
assert g2.is_adjacent('v_4', 'v_5') is True
56+
assert g2.is_adjacent('v_5', 'v_6') is True
57+
assert g2.is_adjacent('v_4', 'v_6') is True
58+
assert g2.is_adjacent('v_5', 'v_4') is False
59+
assert g2.is_adjacent('v_6', 'v_5') is False
60+
assert g2.is_adjacent('v_6', 'v_4') is False
61+
assert g2.num_edges() == 3
62+
assert g2.num_vertices() == 3
63+
neighbors = g2.neighbors('v_4')
64+
assert neighbors == [v_6, v_5]
65+
v = AdjacencyListGraphNode('v', 4, backend = Backend.CPP)
66+
g2.add_vertex(v)
67+
g2.add_edge('v_4', 'v', 0)
68+
g2.add_edge('v_5', 'v', 0)
69+
g2.add_edge('v_6', 'v', "h")
70+
assert g2.is_adjacent('v_4', 'v') is True
71+
assert g2.is_adjacent('v_5', 'v') is True
72+
assert g2.is_adjacent('v_6', 'v') is True
73+
e1 = g2.get_edge('v_4', 'v')
74+
e2 = g2.get_edge('v_5', 'v')
75+
e3 = g2.get_edge('v_6', 'v')
76+
assert (str(e1)) == "('v_4', 'v', 0)"
77+
assert (str(e2)) == "('v_5', 'v', 0)"
78+
assert (str(e3)) == "('v_6', 'v', h)"
79+
g2.remove_edge('v_4', 'v')
80+
assert g2.is_adjacent('v_4', 'v') is False
81+
g2.remove_vertex('v')
82+
assert raises(ValueError, lambda: g2.add_edge('v_4', 'v'))

0 commit comments

Comments
 (0)