Skip to content

Commit 79a8c57

Browse files
committed
finish refactor and add test
1 parent fa60555 commit 79a8c57

File tree

2 files changed

+62
-47
lines changed

2 files changed

+62
-47
lines changed

bson/_cbsonmodule.c

Lines changed: 34 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,59 +1649,48 @@ static int write_raw_doc(buffer_t buffer, PyObject* raw, PyObject* _raw_str) {
16491649
*/
16501650
void handle_invalid_doc_error(PyObject* dict) {
16511651
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
1652+
PyObject *msg = NULL, *dict_str = NULL, *new_msg = NULL;
16521653
PyErr_Fetch(&etype, &evalue, &etrace);
16531654
PyObject *InvalidDocument = _error("InvalidDocument");
16541655
if (InvalidDocument == NULL) {
1655-
PyErr_Restore(etype, evalue, etrace);
1656-
return;
1656+
goto cleanup;
16571657
}
16581658

1659-
if (PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
1660-
1661-
Py_DECREF(etype);
1662-
etype = InvalidDocument;
1663-
1664-
if (evalue) {
1665-
PyObject *msg = PyObject_Str(evalue);
1659+
if (evalue && PyErr_GivenExceptionMatches(etype, InvalidDocument)) {
1660+
PyObject *msg = PyObject_Str(evalue);
1661+
if (msg) {
1662+
// Prepend doc to the existing message
1663+
PyObject *dict_str = PyObject_Str(dict);
1664+
if (dict_str == NULL) {
1665+
goto cleanup;
1666+
}
1667+
const char * dict_str_utf8 = PyUnicode_AsUTF8(dict_str);
1668+
if (dict_str_utf8 == NULL) {
1669+
goto cleanup;
1670+
}
1671+
const char * msg_utf8 = PyUnicode_AsUTF8(msg);
1672+
if (msg_utf8 == NULL) {
1673+
goto cleanup;
1674+
}
1675+
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", dict_str_utf8, msg_utf8);
16661676
Py_DECREF(evalue);
1667-
1668-
if (msg) {
1669-
// Prepend doc to the existing message
1670-
PyObject *dict_str = PyObject_Str(dict);
1671-
if (dict_str == NULL) {
1672-
Py_DECREF(msg);
1673-
return;
1674-
}
1675-
const char * dict_str_utf8 = PyUnicode_AsUTF8(dict);
1676-
Py_DECREF(dict_str);
1677-
if (dict_str_utf8 == NULL) {
1678-
Py_DECREF(msg);
1679-
return;
1680-
}
1681-
const char * msg_utf8 = PyUnicode_AsUTF8(msg);
1682-
if (msg_utf8 == NULL) {
1683-
Py_DECREF(msg);
1684-
return;
1685-
}
1686-
PyObject *new_msg = PyUnicode_FromFormat("Invalid document %s | %s", dict_str_utf8, msg_utf8);
1687-
Py_DECREF(msg_utf8);
1688-
1689-
if (new_msg) {
1690-
Py_DECREF(msg);
1691-
evalue = new_msg;
1692-
}
1693-
else {
1694-
Py_DECREF(new_msg);
1695-
evalue = msg;
1696-
}
1677+
Py_DECREF(etype);
1678+
etype = InvalidDocument;
1679+
InvalidDocument = NULL;
1680+
if (new_msg) {
1681+
evalue = new_msg;
1682+
} else {
1683+
evalue = msg;
16971684
}
16981685
}
16991686
PyErr_NormalizeException(&etype, &evalue, &etrace);
17001687
}
1701-
else {
1702-
Py_DECREF(InvalidDocument);
1703-
}
1688+
cleanup:
17041689
PyErr_Restore(etype, evalue, etrace);
1690+
Py_XDECREF(msg);
1691+
Py_XDECREF(InvalidDocument);
1692+
Py_XDECREF(dict_str);
1693+
Py_XDECREF(new_msg);
17051694
}
17061695

17071696

@@ -1804,8 +1793,6 @@ int write_dict(PyObject* self, buffer_t buffer,
18041793
while (PyDict_Next(dict, &pos, &key, &value)) {
18051794
if (!decode_and_write_pair(self, buffer, key, value,
18061795
check_keys, options, top_level)) {
1807-
Py_DECREF(key);
1808-
Py_DECREF(value);
18091796
if (PyErr_Occurred() && top_level) {
18101797
handle_invalid_doc_error(dict);
18111798
}
@@ -1827,12 +1814,12 @@ int write_dict(PyObject* self, buffer_t buffer,
18271814
}
18281815
if (!decode_and_write_pair(self, buffer, key, value,
18291816
check_keys, options, top_level)) {
1830-
Py_DECREF(key);
1831-
Py_DECREF(value);
1832-
Py_DECREF(iter);
18331817
if (PyErr_Occurred() && top_level) {
18341818
handle_invalid_doc_error(dict);
18351819
}
1820+
Py_DECREF(key);
1821+
Py_DECREF(value);
1822+
Py_DECREF(iter);
18361823
return 0;
18371824
}
18381825
Py_DECREF(key);

test/test_bson.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,34 @@ def __repr__(self):
11121112
with self.assertRaisesRegex(InvalidDocument, f"Invalid document {doc}"):
11131113
encode(doc)
11141114

1115+
def test_doc_in_invalid_document_error_message_mapping(self):
1116+
class MyMapping(abc.Mapping):
1117+
def keys():
1118+
return ["t"]
1119+
1120+
def __getitem__(self, name):
1121+
if name == "_id":
1122+
return None
1123+
return Wrapper(name)
1124+
1125+
def __len__(self):
1126+
return 1
1127+
1128+
def __iter__(self):
1129+
return iter(["t"])
1130+
1131+
class Wrapper:
1132+
def __init__(self, val):
1133+
self.val = val
1134+
1135+
def __repr__(self):
1136+
return repr(self.val)
1137+
1138+
self.assertEqual("1", repr(Wrapper(1)))
1139+
doc = MyMapping()
1140+
with self.assertRaisesRegex(InvalidDocument, f"Invalid document {doc}"):
1141+
encode(doc)
1142+
11151143

11161144
class TestCodecOptions(unittest.TestCase):
11171145
def test_document_class(self):

0 commit comments

Comments
 (0)