Skip to content

Commit a2ba236

Browse files
authored
Added LLVM implementation of bubble sort (#693)
1 parent 79c4328 commit a2ba236

File tree

13 files changed

+966
-81
lines changed

13 files changed

+966
-81
lines changed

pydatastructs/graphs/_backend/cpp/Algorithms.hpp

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* a
1414
PyObject* operation;
1515
PyObject* varargs = nullptr;
1616
PyObject* kwargs_dict = nullptr;
17-
1817
static const char* kwlist[] = {"graph", "source_node", "operation", "args", "kwargs", nullptr};
18+
1919
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!sO|OO", const_cast<char**>(kwlist),
2020
&AdjacencyListGraphType, &graph_obj,
2121
&source_name, &operation,
@@ -24,54 +24,58 @@ static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* a
2424
}
2525

2626
AdjacencyListGraph* cpp_graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);
27-
2827
auto it = cpp_graph->node_map.find(source_name);
2928
AdjacencyListGraphNode* start_node = it->second;
30-
3129
std::unordered_set<std::string> visited;
3230
std::queue<AdjacencyListGraphNode*> q;
33-
3431
q.push(start_node);
3532
visited.insert(start_node->name);
3633

3734
while (!q.empty()) {
38-
AdjacencyListGraphNode* node = q.front();
39-
q.pop();
35+
AdjacencyListGraphNode* node = q.front();
36+
q.pop();
4037

41-
for (const auto& [adj_name, adj_obj] : node->adjacent) {
42-
if (visited.count(adj_name)) continue;
43-
if (get_type_tag(adj_obj) != NodeType_::AdjacencyListGraphNode) continue;
38+
for (const auto& [adj_name, adj_obj] : node->adjacent) {
39+
if (visited.count(adj_name)) continue;
40+
if (get_type_tag(adj_obj) != NodeType_::AdjacencyListGraphNode) continue;
4441

45-
AdjacencyListGraphNode* adj_node = reinterpret_cast<AdjacencyListGraphNode*>(adj_obj);
42+
AdjacencyListGraphNode* adj_node = reinterpret_cast<AdjacencyListGraphNode*>(adj_obj);
4643

47-
PyObject* base_args = PyTuple_Pack(2,
48-
reinterpret_cast<PyObject*>(node),
49-
reinterpret_cast<PyObject*>(adj_node));
50-
if (!base_args)
51-
return nullptr;
44+
PyObject* node_pyobj = reinterpret_cast<PyObject*>(node);
45+
PyObject* adj_node_pyobj = reinterpret_cast<PyObject*>(adj_node);
5246

53-
PyObject* final_args;
54-
if (varargs && PyTuple_Check(varargs)) {
55-
final_args = PySequence_Concat(base_args, varargs);
56-
Py_DECREF(base_args);
47+
PyObject* final_args;
48+
49+
if (varargs && PyTuple_Check(varargs)) {
50+
Py_ssize_t varargs_size = PyTuple_Size(varargs);
51+
if (varargs_size == 1) {
52+
PyObject* extra_arg = PyTuple_GetItem(varargs, 0);
53+
final_args = PyTuple_Pack(3, node_pyobj, adj_node_pyobj, extra_arg);
54+
} else {
55+
PyObject* base_args = PyTuple_Pack(2, node_pyobj, adj_node_pyobj);
56+
if (!base_args)
57+
return nullptr;
58+
final_args = PySequence_Concat(base_args, varargs);
59+
Py_DECREF(base_args);
60+
}
61+
} else {
62+
final_args = PyTuple_Pack(2, node_pyobj, adj_node_pyobj);
63+
}
5764
if (!final_args)
5865
return nullptr;
59-
} else {
60-
final_args = base_args;
61-
}
62-
63-
PyObject* result = PyObject_Call(operation, final_args, kwargs_dict);
64-
Py_DECREF(final_args);
6566

66-
if (!result)
67-
return nullptr;
67+
PyObject* result = PyObject_Call(operation, final_args, kwargs_dict);
68+
Py_DECREF(final_args);
6869

69-
Py_DECREF(result);
70+
if (!result)
71+
return nullptr;
7072

71-
visited.insert(adj_name);
72-
q.push(adj_node);
73-
}
73+
Py_DECREF(result);
74+
visited.insert(adj_name);
75+
q.push(adj_node);
76+
}
7477
}
78+
7579
if (PyErr_Occurred()) {
7680
return nullptr;
7781
}

pydatastructs/graphs/tests/test_algorithms.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,17 @@ def bfs_tree(curr_node, next_node, parent):
4545
(parent[V3.name] == V2.name and parent[V2.name] == V1.name)
4646

4747
if (ds=='List'):
48-
parent2 = {}
48+
parent = {}
4949
V9 = AdjacencyListGraphNode("9",0,backend = Backend.CPP)
5050
V10 = AdjacencyListGraphNode("10",0,backend = Backend.CPP)
5151
V11 = AdjacencyListGraphNode("11",0,backend = Backend.CPP)
5252
G2 = Graph(V9, V10, V11,implementation = 'adjacency_list', backend = Backend.CPP)
5353
assert G2.num_vertices()==3
5454
G2.add_edge("9", "10")
5555
G2.add_edge("10", "11")
56-
breadth_first_search(G2, "9", bfs_tree, parent2, backend = Backend.CPP)
57-
assert parent2[V10] == V9
58-
assert parent2[V11] == V10
56+
breadth_first_search(G2, "9", bfs_tree, parent, backend = Backend.CPP)
57+
assert parent[V10] == V9
58+
assert parent[V11] == V10
5959

6060
if (ds == 'Matrix'):
6161
parent3 = {}

pydatastructs/linear_data_structures/_backend/cpp/algorithms/__init__.py

Whitespace-only changes.

pydatastructs/linear_data_structures/_backend/cpp/algorithms/algorithms.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
88
METH_VARARGS | METH_KEYWORDS, ""},
99
{"bubble_sort", (PyCFunction) bubble_sort,
1010
METH_VARARGS | METH_KEYWORDS, ""},
11+
{"bubble_sort_llvm", (PyCFunction)bubble_sort_llvm,
12+
METH_VARARGS | METH_KEYWORDS, ""},
1113
{"selection_sort", (PyCFunction) selection_sort,
1214
METH_VARARGS | METH_KEYWORDS, ""},
1315
{"insertion_sort", (PyCFunction) insertion_sort,
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
from llvmlite import ir, binding
2+
import atexit
3+
4+
# Element types the generated sort can handle, keyed by dtype name.
# Each entry is (LLVM IR type, width). The width looks like the element size
# in bytes and is not read by the code visible in this file -- TODO confirm
# whether another consumer relies on it.
_SUPPORTED = dict(
    int32=(ir.IntType(32), 4),
    int64=(ir.IntType(64), 8),
    float32=(ir.FloatType(), 4),
    float64=(ir.DoubleType(), 8),
)

# Module-level JIT state, populated lazily.
_fn_ptr_cache = {}      # dtype name -> address obtained from get_function_address
_target_machine = None  # set once by _ensure_target_machine()
_engines = {}           # dtype name -> MCJIT engine stored next to its cached address
14+
15+
@atexit.register
def _cleanup():
    """Release the module's cached LLVM state when the interpreter exits."""
    # Only the rebinding below needs ``global``; the two dicts are emptied in place.
    global _target_machine
    _engines.clear()
    _target_machine = None
    _fn_ptr_cache.clear()
23+
24+
def _ensure_target_machine():
    """Create the shared native target machine on first use.

    Does nothing when ``_target_machine`` is already set. Otherwise it
    initializes LLVM's native target and asm printer through
    ``llvmlite.binding`` and stores a target machine for the default triple
    in the module-level ``_target_machine``.

    Raises:
        RuntimeError: if any initialization step fails; the original
            exception is attached as ``__cause__``.
    """
    global _target_machine
    if _target_machine is not None:
        return

    try:
        # NOTE(review): newer llvmlite releases appear to deprecate
        # binding.initialize(); confirm against the pinned llvmlite version.
        binding.initialize()
        binding.initialize_native_target()
        binding.initialize_native_asmprinter()

        target = binding.Target.from_default_triple()
        _target_machine = target.create_target_machine()
    except Exception as e:
        # Chain explicitly so the underlying LLVM error is reported as the
        # direct cause instead of as incidental context.
        raise RuntimeError(f"Failed to initialize LLVM target machine: {e}") from e
38+
39+
def get_bubble_sort_ptr(dtype: str) -> int:
    """Return the address of the JIT-compiled bubble sort for ``dtype``.

    ``dtype`` is lower-cased and stripped of surrounding whitespace before
    the lookup, so ``" Int32 "`` resolves to the same routine as ``"int32"``.

    Raises:
        ValueError: if the normalized name is not a key of ``_SUPPORTED``.
    """
    dtype = dtype.lower().strip()
    if dtype in _SUPPORTED:
        return _materialize(dtype)
    raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
46+
47+
def _build_bubble_sort_ir(dtype: str) -> str:
    """Emit textual LLVM IR for an in-place bubble sort over ``dtype`` elements.

    The module contains a single function, ``bubble_sort_<dtype>``, that takes
    a pointer to the elements and a signed 32-bit element count and returns
    void. Integer elements are compared as signed values; floating-point
    elements use an ordered ``>`` comparison.

    Raises:
        ValueError: if ``dtype`` is not a key of ``_SUPPORTED``.
    """
    if dtype not in _SUPPORTED:
        raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

    elem_ty = _SUPPORTED[dtype][0]
    i32 = ir.IntType(32)
    i64 = ir.IntType(64)
    zero = ir.Constant(i32, 0)
    one = ir.Constant(i32, 1)

    module = ir.Module(name=f"bubble_sort_{dtype}_module")
    sort_fn = ir.Function(
        module,
        ir.FunctionType(ir.VoidType(), [elem_ty.as_pointer(), i32]),
        name=f"bubble_sort_{dtype}",
    )
    data, count = sort_fn.args
    data.name = "arr"
    count.name = "n"

    # Blocks are appended in the order they should appear in the function.
    (entry_bb, outer_bb, inner_init_bb, inner_bb, body_bb,
     swap_bb, inner_latch_bb, outer_latch_bb, exit_bb) = [
        sort_fn.append_basic_block(label)
        for label in ("entry", "outer", "inner.init", "inner", "body",
                      "swap", "inner.latch", "outer.latch", "exit")
    ]

    builder = ir.IRBuilder(entry_bb)

    # entry: a count of 1 or less leaves nothing to sort.
    builder.cbranch(builder.icmp_signed("<=", count, one), exit_bb, outer_bb)

    # outer: loop header for pass index i, running while i < n - 1.
    builder.position_at_end(outer_bb)
    i = builder.phi(i32, name="i")
    i.add_incoming(zero, entry_bb)
    last_index = builder.sub(count, one, name="n_minus_1")
    builder.cbranch(builder.icmp_signed("<", i, last_index), inner_init_bb, exit_bb)

    # inner.init: each pass scans j over [0, n - 1 - i).
    builder.position_at_end(inner_init_bb)
    scan_limit = builder.sub(last_index, i, name="inner_limit")
    builder.branch(inner_bb)

    # inner: loop header for scan index j.
    builder.position_at_end(inner_bb)
    j = builder.phi(i32, name="j")
    j.add_incoming(zero, inner_init_bb)
    builder.cbranch(builder.icmp_signed("<", j, scan_limit), body_bb, outer_latch_bb)

    # body: load arr[j] and arr[j + 1] and decide whether they are out of order.
    builder.position_at_end(body_bb)
    j_wide = builder.sext(j, i64)
    j_plus_one = builder.add(j, one)
    j_plus_one_wide = builder.sext(j_plus_one, i64)
    left_ptr = builder.gep(data, [j_wide], inbounds=True)
    right_ptr = builder.gep(data, [j_plus_one_wide], inbounds=True)
    left = builder.load(left_ptr)
    right = builder.load(right_ptr)
    if isinstance(elem_ty, ir.IntType):
        compare = builder.icmp_signed
    else:
        compare = builder.fcmp_ordered
    builder.cbranch(compare(">", left, right), swap_bb, inner_latch_bb)

    # swap: write the two loaded values back in exchanged positions.
    builder.position_at_end(swap_bb)
    builder.store(right, left_ptr)
    builder.store(left, right_ptr)
    builder.branch(inner_latch_bb)

    # inner.latch: j += 1, then back to the inner header.
    builder.position_at_end(inner_latch_bb)
    j.add_incoming(builder.add(j, one), inner_latch_bb)
    builder.branch(inner_bb)

    # outer.latch: i += 1, then back to the outer header.
    builder.position_at_end(outer_latch_bb)
    i.add_incoming(builder.add(i, one), outer_latch_bb)
    builder.branch(outer_bb)

    builder.position_at_end(exit_bb)
    builder.ret_void()

    return str(module)
136+
137+
def _materialize(dtype: str) -> int:
    """JIT-compile the bubble sort for ``dtype`` and return its native address.

    The target machine is ensured first. A previously compiled address is
    returned straight from ``_fn_ptr_cache``; otherwise the IR from
    ``_build_bubble_sort_ir`` is parsed, verified and compiled with MCJIT,
    and both the address and the engine are cached under ``dtype``.

    Args:
        dtype: dtype name used to build the IR and as the cache key.

    Returns:
        The integer address of ``bubble_sort_<dtype>``.

    Raises:
        RuntimeError: if the target machine cannot be initialized, or if any
            step of building, verifying or compiling fails. In the latter
            case the original exception is attached as ``__cause__``.
    """
    _ensure_target_machine()

    if dtype in _fn_ptr_cache:
        return _fn_ptr_cache[dtype]

    try:
        llvm_ir = _build_bubble_sort_ir(dtype)
        mod = binding.parse_assembly(llvm_ir)
        mod.verify()

        engine = binding.create_mcjit_compiler(mod, _target_machine)
        engine.finalize_object()
        engine.run_static_constructors()

        addr = engine.get_function_address(f"bubble_sort_{dtype}")
        if not addr:
            raise RuntimeError(f"Failed to get address for bubble_sort_{dtype}")

        _fn_ptr_cache[dtype] = addr
        # Keep the engine referenced next to the address; presumably the
        # compiled code is only valid while its engine is alive -- TODO confirm.
        _engines[dtype] = engine

        return addr

    except Exception as e:
        # Chain explicitly so the root failure (parse, verify, JIT, ...) is
        # shown as the direct cause of the wrapped error.
        raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}") from e

0 commit comments

Comments
 (0)