Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 35 additions & 31 deletions pydatastructs/graphs/_backend/cpp/Algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* a
PyObject* operation;
PyObject* varargs = nullptr;
PyObject* kwargs_dict = nullptr;

static const char* kwlist[] = {"graph", "source_node", "operation", "args", "kwargs", nullptr};

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!sO|OO", const_cast<char**>(kwlist),
&AdjacencyListGraphType, &graph_obj,
&source_name, &operation,
Expand All @@ -24,54 +24,58 @@ static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* a
}

AdjacencyListGraph* cpp_graph = reinterpret_cast<AdjacencyListGraph*>(graph_obj);

auto it = cpp_graph->node_map.find(source_name);
AdjacencyListGraphNode* start_node = it->second;

std::unordered_set<std::string> visited;
std::queue<AdjacencyListGraphNode*> q;

q.push(start_node);
visited.insert(start_node->name);

while (!q.empty()) {
AdjacencyListGraphNode* node = q.front();
q.pop();
AdjacencyListGraphNode* node = q.front();
q.pop();

for (const auto& [adj_name, adj_obj] : node->adjacent) {
if (visited.count(adj_name)) continue;
if (get_type_tag(adj_obj) != NodeType_::AdjacencyListGraphNode) continue;
for (const auto& [adj_name, adj_obj] : node->adjacent) {
if (visited.count(adj_name)) continue;
if (get_type_tag(adj_obj) != NodeType_::AdjacencyListGraphNode) continue;

AdjacencyListGraphNode* adj_node = reinterpret_cast<AdjacencyListGraphNode*>(adj_obj);
AdjacencyListGraphNode* adj_node = reinterpret_cast<AdjacencyListGraphNode*>(adj_obj);

PyObject* base_args = PyTuple_Pack(2,
reinterpret_cast<PyObject*>(node),
reinterpret_cast<PyObject*>(adj_node));
if (!base_args)
return nullptr;
PyObject* node_pyobj = reinterpret_cast<PyObject*>(node);
PyObject* adj_node_pyobj = reinterpret_cast<PyObject*>(adj_node);

PyObject* final_args;
if (varargs && PyTuple_Check(varargs)) {
final_args = PySequence_Concat(base_args, varargs);
Py_DECREF(base_args);
PyObject* final_args;

if (varargs && PyTuple_Check(varargs)) {
Py_ssize_t varargs_size = PyTuple_Size(varargs);
if (varargs_size == 1) {
PyObject* extra_arg = PyTuple_GetItem(varargs, 0);
final_args = PyTuple_Pack(3, node_pyobj, adj_node_pyobj, extra_arg);
} else {
PyObject* base_args = PyTuple_Pack(2, node_pyobj, adj_node_pyobj);
if (!base_args)
return nullptr;
final_args = PySequence_Concat(base_args, varargs);
Py_DECREF(base_args);
}
} else {
final_args = PyTuple_Pack(2, node_pyobj, adj_node_pyobj);
}
if (!final_args)
return nullptr;
} else {
final_args = base_args;
}

PyObject* result = PyObject_Call(operation, final_args, kwargs_dict);
Py_DECREF(final_args);

if (!result)
return nullptr;
PyObject* result = PyObject_Call(operation, final_args, kwargs_dict);
Py_DECREF(final_args);

Py_DECREF(result);
if (!result)
return nullptr;

visited.insert(adj_name);
q.push(adj_node);
}
Py_DECREF(result);
visited.insert(adj_name);
q.push(adj_node);
}
}

if (PyErr_Occurred()) {
return nullptr;
}
Expand Down
8 changes: 4 additions & 4 deletions pydatastructs/graphs/tests/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,17 @@ def bfs_tree(curr_node, next_node, parent):
(parent[V3.name] == V2.name and parent[V2.name] == V1.name)

if (ds=='List'):
parent2 = {}
parent = {}
V9 = AdjacencyListGraphNode("9",0,backend = Backend.CPP)
V10 = AdjacencyListGraphNode("10",0,backend = Backend.CPP)
V11 = AdjacencyListGraphNode("11",0,backend = Backend.CPP)
G2 = Graph(V9, V10, V11,implementation = 'adjacency_list', backend = Backend.CPP)
assert G2.num_vertices()==3
G2.add_edge("9", "10")
G2.add_edge("10", "11")
breadth_first_search(G2, "9", bfs_tree, parent2, backend = Backend.CPP)
assert parent2[V10] == V9
assert parent2[V11] == V10
breadth_first_search(G2, "9", bfs_tree, parent, backend = Backend.CPP)
assert parent[V10] == V9
assert parent[V11] == V10

if (ds == 'Matrix'):
parent3 = {}
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ static PyMethodDef algorithms_PyMethodDef[] = {
METH_VARARGS | METH_KEYWORDS, ""},
{"bubble_sort", (PyCFunction) bubble_sort,
METH_VARARGS | METH_KEYWORDS, ""},
{"bubble_sort_llvm", (PyCFunction)bubble_sort_llvm,
METH_VARARGS | METH_KEYWORDS, ""},
{"selection_sort", (PyCFunction) selection_sort,
METH_VARARGS | METH_KEYWORDS, ""},
{"insertion_sort", (PyCFunction) insertion_sort,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from llvmlite import ir, binding
import atexit

_SUPPORTED = {
"int32": (ir.IntType(32), 4),
"int64": (ir.IntType(64), 8),
"float32": (ir.FloatType(), 4),
"float64": (ir.DoubleType(), 8),
}

_engines = {}
_target_machine = None
_fn_ptr_cache = {}

def _cleanup():
"""Clean up LLVM resources on exit."""
global _engines, _target_machine, _fn_ptr_cache
_engines.clear()
_target_machine = None
_fn_ptr_cache.clear()

atexit.register(_cleanup)

def _ensure_target_machine():
global _target_machine
if _target_machine is not None:
return

try:
binding.initialize()
binding.initialize_native_target()
binding.initialize_native_asmprinter()

target = binding.Target.from_default_triple()
_target_machine = target.create_target_machine()
except Exception as e:
raise RuntimeError(f"Failed to initialize LLVM target machine: {e}")

def get_bubble_sort_ptr(dtype: str) -> int:
"""Get function pointer for bubble sort with specified dtype."""
dtype = dtype.lower().strip()
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

return _materialize(dtype)

def _build_bubble_sort_ir(dtype: str) -> str:
if dtype not in _SUPPORTED:
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")

T, _ = _SUPPORTED[dtype]
i32 = ir.IntType(32)
i64 = ir.IntType(64)

mod = ir.Module(name=f"bubble_sort_{dtype}_module")
fn_name = f"bubble_sort_{dtype}"

fn_ty = ir.FunctionType(ir.VoidType(), [T.as_pointer(), i32])
fn = ir.Function(mod, fn_ty, name=fn_name)

arr, n = fn.args
arr.name, n.name = "arr", "n"

b_entry = fn.append_basic_block("entry")
b_outer = fn.append_basic_block("outer")
b_inner_init = fn.append_basic_block("inner.init")
b_inner = fn.append_basic_block("inner")
b_body = fn.append_basic_block("body")
b_swap = fn.append_basic_block("swap")
b_inner_latch = fn.append_basic_block("inner.latch")
b_outer_latch = fn.append_basic_block("outer.latch")
b_exit = fn.append_basic_block("exit")

b = ir.IRBuilder(b_entry)

cond_trivial = b.icmp_signed("<=", n, ir.Constant(i32, 1))
b.cbranch(cond_trivial, b_exit, b_outer)

b.position_at_end(b_outer)
i_phi = b.phi(i32, name="i")
i_phi.add_incoming(ir.Constant(i32, 0), b_entry)

n1 = b.sub(n, ir.Constant(i32, 1), name="n_minus_1")
cond_outer = b.icmp_signed("<", i_phi, n1)
b.cbranch(cond_outer, b_inner_init, b_exit)

b.position_at_end(b_inner_init)

inner_limit = b.sub(n1, i_phi, name="inner_limit")
b.branch(b_inner)

b.position_at_end(b_inner)
j_phi = b.phi(i32, name="j")
j_phi.add_incoming(ir.Constant(i32, 0), b_inner_init)

cond_inner = b.icmp_signed("<", j_phi, inner_limit)
b.cbranch(cond_inner, b_body, b_outer_latch)

b.position_at_end(b_body)
j64 = b.sext(j_phi, i64)
jp1 = b.add(j_phi, ir.Constant(i32, 1))
jp1_64 = b.sext(jp1, i64)

ptr_j = b.gep(arr, [j64], inbounds=True)
ptr_jp1 = b.gep(arr, [jp1_64], inbounds=True)

val_j = b.load(ptr_j)
val_jp1 = b.load(ptr_jp1)

if isinstance(T, ir.IntType):
should_swap = b.icmp_signed(">", val_j, val_jp1)
else:
should_swap = b.fcmp_ordered(">", val_j, val_jp1)

b.cbranch(should_swap, b_swap, b_inner_latch)

b.position_at_end(b_swap)
b.store(val_jp1, ptr_j)
b.store(val_j, ptr_jp1)
b.branch(b_inner_latch)

b.position_at_end(b_inner_latch)
j_next = b.add(j_phi, ir.Constant(i32, 1))
j_phi.add_incoming(j_next, b_inner_latch)
b.branch(b_inner)

b.position_at_end(b_outer_latch)
i_next = b.add(i_phi, ir.Constant(i32, 1))
i_phi.add_incoming(i_next, b_outer_latch)
b.branch(b_outer)

b.position_at_end(b_exit)
b.ret_void()

return str(mod)

def _materialize(dtype: str) -> int:
_ensure_target_machine()

if dtype in _fn_ptr_cache:
return _fn_ptr_cache[dtype]

try:
llvm_ir = _build_bubble_sort_ir(dtype)
mod = binding.parse_assembly(llvm_ir)
mod.verify()

engine = binding.create_mcjit_compiler(mod, _target_machine)
engine.finalize_object()
engine.run_static_constructors()

addr = engine.get_function_address(f"bubble_sort_{dtype}")
if not addr:
raise RuntimeError(f"Failed to get address for bubble_sort_{dtype}")

_fn_ptr_cache[dtype] = addr
_engines[dtype] = engine

return addr

except Exception as e:
raise RuntimeError(f"Failed to materialize function for dtype {dtype}: {e}")
Loading
Loading