Skip to content

Commit de01034

Browse files
Prerak SinghPrerak Singh
authored andcommitted
added support for pyobject
1 parent 6d81439 commit de01034

File tree

2 files changed

+65
-28
lines changed

2 files changed

+65
-28
lines changed

pydatastructs/utils/_backend/cpp/AdjacencyListGraphNode.hpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,29 @@ extern PyTypeObject AdjacencyListGraphNodeType;
1313
typedef struct {
1414
PyObject_HEAD
1515
std::string name;
16-
std::variant<std::monostate, int64_t, double, std::string> data;
16+
std::variant<std::monostate, int64_t, double, std::string, PyObject*> data;
1717
DataType data_type;
1818
std::unordered_map<std::string, PyObject*> adjacent;
1919
} AdjacencyListGraphNode;
2020

2121
static void AdjacencyListGraphNode_dealloc(AdjacencyListGraphNode* self) {
22+
if (self->data_type == DataType::PyObject) {
23+
Py_XDECREF(std::get<PyObject*>(self->data));
24+
}
25+
2226
for (auto& pair : self->adjacent) {
2327
Py_XDECREF(pair.second);
2428
}
2529
self->adjacent.clear();
26-
Py_TYPE(self)->tp_free(reinterpret_cast<PyTypeObject*>(self));
30+
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
2731
}
2832

2933
static PyObject* AdjacencyListGraphNode_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
3034
AdjacencyListGraphNode* self = PyObject_New(AdjacencyListGraphNode, &AdjacencyListGraphNodeType);
3135
if (!self) return NULL;
3236
new (&self->adjacent) std::unordered_map<std::string, PyObject*>();
3337
new (&self->name) std::string();
34-
new (&self->data) std::variant<std::monostate, int64_t, double, std::string>();
38+
new (&self->data) std::variant<std::monostate, int64_t, double, std::string, PyObject*>();
3539
self->data_type = DataType::None;
3640
self->data = std::monostate{};
3741

@@ -61,16 +65,16 @@ static PyObject* AdjacencyListGraphNode_new(PyTypeObject* type, PyObject* args,
6165
self->data_type = DataType::String;
6266
self->data = std::string(str);
6367
} else {
64-
PyErr_SetString(PyExc_TypeError, "Unsupported data type. Must be int, float, str, or None.");
65-
return NULL;
68+
self->data_type = DataType::PyObject;
69+
Py_INCREF(data);
70+
self->data = data;
6671
}
6772

6873
if (PyList_Check(adjacency_list)) {
6974
Py_ssize_t size = PyList_Size(adjacency_list);
7075
for (Py_ssize_t i = 0; i < size; i++) {
7176
PyObject* node = PyList_GetItem(adjacency_list, i);
7277

73-
7478
if (PyType_Ready(&AdjacencyListGraphNodeType) < 0) {
7579
PyErr_SetString(PyExc_RuntimeError, "Failed to initialize AdjacencyListGraphNodeType");
7680
return NULL;
@@ -85,7 +89,7 @@ static PyObject* AdjacencyListGraphNode_new(PyTypeObject* type, PyObject* args,
8589
std::string str = std::string(adj_name);
8690
Py_INCREF(node);
8791
self->adjacent[str] = node;
88-
}
92+
}
8993
}
9094

9195
return reinterpret_cast<PyObject*>(self);
@@ -154,13 +158,20 @@ static PyObject* AdjacencyListGraphNode_get_data(AdjacencyListGraphNode* self, v
154158
return PyFloat_FromDouble(std::get<double>(self->data));
155159
case DataType::String:
156160
return PyUnicode_FromString(std::get<std::string>(self->data).c_str());
161+
case DataType::PyObject:
162+
Py_INCREF(std::get<PyObject*>(self->data));
163+
return std::get<PyObject*>(self->data);
157164
case DataType::None:
158165
default:
159166
Py_RETURN_NONE;
160167
}
161168
}
162169

163170
static int AdjacencyListGraphNode_set_data(AdjacencyListGraphNode* self, PyObject* value, void* closure) {
171+
if (self->data_type == DataType::PyObject) {
172+
Py_XDECREF(std::get<PyObject*>(self->data));
173+
}
174+
164175
if (value == Py_None) {
165176
self->data_type = DataType::None;
166177
self->data = std::monostate{};
@@ -179,8 +190,9 @@ static int AdjacencyListGraphNode_set_data(AdjacencyListGraphNode* self, PyObjec
179190
self->data_type = DataType::String;
180191
self->data = std::string(str);
181192
} else {
182-
PyErr_SetString(PyExc_TypeError, "Unsupported data type. Must be int, float, str, or None.");
183-
return -1;
193+
self->data_type = DataType::PyObject;
194+
Py_INCREF(value);
195+
self->data = value;
184196
}
185197
return 0;
186198
}
@@ -227,7 +239,6 @@ static PyGetSetDef AdjacencyListGraphNode_getsetters[] = {
227239
{NULL}
228240
};
229241

230-
231242
static PyMethodDef AdjacencyListGraphNode_methods[] = {
232243
{"add_adjacent_node", (PyCFunction)AdjacencyListGraphNode_add_adjacent_node, METH_VARARGS, "Add adjacent node"},
233244
{"remove_adjacent_node", (PyCFunction)AdjacencyListGraphNode_remove_adjacent_node, METH_VARARGS, "Remove adjacent node"},

pydatastructs/utils/_backend/cpp/GraphNode.hpp

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,31 @@ enum class DataType {
1010
None,
1111
Int,
1212
Double,
13-
String
13+
String,
14+
PyObject
1415
};
1516

1617
typedef struct {
1718
PyObject_HEAD
1819
std::string name;
19-
std::variant<std::monostate, int64_t, double, std::string> data;
20+
std::variant<std::monostate, int64_t, double, std::string, PyObject*> data;
2021
DataType data_type;
2122
} GraphNode;
2223

23-
static void GraphNode_dealloc(GraphNode* self){
24-
Py_TYPE(self)->tp_free(reinterpret_cast<PyTypeObject*>(self));
24+
static void GraphNode_dealloc(GraphNode* self) {
25+
if (self->data_type == DataType::PyObject) {
26+
Py_XDECREF(std::get<PyObject*>(self->data));
27+
}
28+
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
2529
}
2630

27-
static PyObject* GraphNode_new(PyTypeObject* type, PyObject* args, PyObject* kwds){
28-
GraphNode* self;
29-
self = reinterpret_cast<GraphNode*>(type->tp_alloc(type,0));
31+
static PyObject* GraphNode_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
32+
GraphNode* self = reinterpret_cast<GraphNode*>(type->tp_alloc(type, 0));
33+
if (!self) return NULL;
34+
3035
new (&self->name) std::string();
31-
new (&self->data) std::variant<std::monostate, int64_t, double, std::string>();
36+
new (&self->data) std::variant<std::monostate, int64_t, double, std::string, PyObject*>();
3237
self->data_type = DataType::None;
33-
if (!self) return NULL;
3438

3539
static char* kwlist[] = { "name", "data", NULL };
3640
const char* name;
@@ -57,8 +61,9 @@ static PyObject* GraphNode_new(PyTypeObject* type, PyObject* args, PyObject* kwd
5761
self->data = std::string(s);
5862
self->data_type = DataType::String;
5963
} else {
60-
PyErr_SetString(PyExc_TypeError, "data must be int, float, str, or None");
61-
return NULL;
64+
self->data = data;
65+
self->data_type = DataType::PyObject;
66+
Py_INCREF(data);
6267
}
6368

6469
return reinterpret_cast<PyObject*>(self);
@@ -80,15 +85,28 @@ static PyObject* GraphNode_str(GraphNode* self) {
8085
case DataType::String:
8186
repr += "'" + std::get<std::string>(self->data) + "'";
8287
break;
88+
case DataType::PyObject: {
89+
PyObject* repr_obj = PyObject_Repr(std::get<PyObject*>(self->data));
90+
if (repr_obj) {
91+
const char* repr_cstr = PyUnicode_AsUTF8(repr_obj);
92+
repr += repr_cstr ? repr_cstr : "<unprintable>";
93+
Py_DECREF(repr_obj);
94+
} else {
95+
repr += "<error in repr>";
96+
}
97+
break;
98+
}
8399
}
100+
84101
repr += ")";
85102
return PyUnicode_FromString(repr.c_str());
86103
}
87104

88105
static PyObject* GraphNode_get(GraphNode* self, void *closure) {
89-
if (closure == (void*)"name") {
106+
const char* attr = reinterpret_cast<const char*>(closure);
107+
if (strcmp(attr, "name") == 0) {
90108
return PyUnicode_FromString(self->name.c_str());
91-
} else if (closure == (void*)"data") {
109+
} else if (strcmp(attr, "data") == 0) {
92110
switch (self->data_type) {
93111
case DataType::None:
94112
Py_RETURN_NONE;
@@ -98,25 +116,32 @@ static PyObject* GraphNode_get(GraphNode* self, void *closure) {
98116
return PyFloat_FromDouble(std::get<double>(self->data));
99117
case DataType::String:
100118
return PyUnicode_FromString(std::get<std::string>(self->data).c_str());
119+
case DataType::PyObject:
120+
Py_INCREF(std::get<PyObject*>(self->data));
121+
return std::get<PyObject*>(self->data);
101122
}
102123
}
103124
Py_RETURN_NONE;
104125
}
105126

106127
static int GraphNode_set(GraphNode* self, PyObject *value, void *closure) {
128+
const char* attr = reinterpret_cast<const char*>(closure);
107129
if (!value) {
108130
PyErr_SetString(PyExc_ValueError, "Cannot delete attributes");
109131
return -1;
110132
}
111133

112-
if (closure == (void*)"name") {
134+
if (strcmp(attr, "name") == 0) {
113135
if (!PyUnicode_Check(value)) {
114136
PyErr_SetString(PyExc_TypeError, "name must be a string");
115137
return -1;
116138
}
117139
self->name = PyUnicode_AsUTF8(value);
118-
}
119-
else if (closure == (void*)"data") {
140+
} else if (strcmp(attr, "data") == 0) {
141+
if (self->data_type == DataType::PyObject) {
142+
Py_XDECREF(std::get<PyObject*>(self->data));
143+
}
144+
120145
if (value == Py_None) {
121146
self->data = std::monostate{};
122147
self->data_type = DataType::None;
@@ -130,8 +155,9 @@ static int GraphNode_set(GraphNode* self, PyObject *value, void *closure) {
130155
self->data = std::string(PyUnicode_AsUTF8(value));
131156
self->data_type = DataType::String;
132157
} else {
133-
PyErr_SetString(PyExc_TypeError, "data must be int, float, str, or None");
134-
return -1;
158+
Py_INCREF(value);
159+
self->data = value;
160+
self->data_type = DataType::PyObject;
135161
}
136162
} else {
137163
PyErr_SetString(PyExc_AttributeError, "Unknown attribute");

0 commit comments

Comments
 (0)