Skip to content
277 changes: 145 additions & 132 deletions bson/_cbsonmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ struct module_state {
PyObject* Decimal128;
PyObject* Mapping;
PyObject* DatetimeMS;
PyObject* _min_datetime_ms;
PyObject* _max_datetime_ms;
PyObject* _type_marker_str;
PyObject* _flags_str;
PyObject* _pattern_str;
Expand All @@ -80,6 +78,10 @@ struct module_state {
PyObject* _from_uuid_str;
PyObject* _as_uuid_str;
PyObject* _from_bid_str;
PyObject* min_datetime;
PyObject* max_datetime;
int64_t min_millis;
int64_t max_millis;
};

#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
Expand Down Expand Up @@ -253,7 +255,7 @@ static PyObject* datetime_from_millis(long long millis) {
* 2. Multiply that by 1000: 253402300799000
* 3. Add in microseconds divided by 1000 253402300799999
*
* (Note: BSON doesn't support microsecond accuracy, hence the rounding.)
* (Note: BSON doesn't support microsecond accuracy, hence the truncation.)
*
* To decode we could do:
* 1. Get seconds: timestamp / 1000: 253402300799
Expand Down Expand Up @@ -376,6 +378,110 @@ static int millis_from_datetime_ms(PyObject* dt, long long* out){
return 1;
}

static PyObject* decode_datetime(PyObject* self, long long millis, const codec_options_t* options){
PyObject* naive = NULL;
PyObject* replace = NULL;
PyObject* args = NULL;
PyObject* kwargs = NULL;
PyObject* value = NULL;
struct module_state *state = GETSTATE(self);
if (options->datetime_conversion == DATETIME_MS){
return datetime_ms_from_millis(self, millis);
}

int dt_clamp = options->datetime_conversion == DATETIME_CLAMP;
int dt_auto = options->datetime_conversion == DATETIME_AUTO;

if (dt_clamp || dt_auto){
if (dt_clamp) {
if (millis < state->min_millis) {
millis = state->min_millis;
} else if (millis > state->max_millis) {
millis = state->max_millis;
}
// Continues from here to return a datetime.
} else { // dt_auto
if (millis < state->min_millis || millis > state->max_millis){
return datetime_ms_from_millis(self, millis);
}
}
}

naive = datetime_from_millis(millis);
if (!naive) {
goto invalid;
}

if (!options->tz_aware) { /* In the naive case, we're done here. */
return naive;
}
replace = PyObject_GetAttr(naive, state->_replace_str);
if (!replace) {
goto invalid;
}
args = PyTuple_New(0);
if (!args) {
goto invalid;
}
kwargs = PyDict_New();
if (!kwargs) {
goto invalid;
}
if (PyDict_SetItem(kwargs, state->_tzinfo_str, state->UTC) == -1) {
goto invalid;
}
value = PyObject_Call(replace, args, kwargs);
if (!value) {
goto invalid;
}

/* convert to local time */
if (options->tzinfo != Py_None) {
PyObject* temp = PyObject_CallMethodObjArgs(value, state->_astimezone_str, options->tzinfo, NULL);
Py_DECREF(value);
value = temp;
if (!value && (dt_clamp || dt_auto)) {
PyObject *etype = NULL, *evalue = NULL, *etrace = NULL;
// Calling PyErr_Fetch clears the error state.
PyErr_Fetch(&etype, &evalue, &etrace);
// Catch overflow due to timezone via PyExc_ArithmeticError exceptions.
if (!PyErr_GivenExceptionMatches(etype, PyExc_ArithmeticError)) {
// Steals references to args.
PyErr_Restore(etype, evalue, etrace);
goto invalid;
}
Py_XDECREF(etype);
Py_XDECREF(evalue);
Py_XDECREF(etrace);
if (dt_clamp) {
PyObject* dtm;
Py_XDECREF(replace);
if (millis < 0) {
dtm = state->min_datetime;
} else {
dtm = state->max_datetime;
}
if (PyDict_SetItem(kwargs, state->_tzinfo_str, options->tzinfo) == -1) {
goto invalid;
}
replace = PyObject_GetAttr(dtm, state->_replace_str);
if (!replace) {
goto invalid;
}
value = PyObject_Call(replace, args, kwargs);
} else { // dt_auto
value = datetime_ms_from_millis(self, millis);
}
}
}
invalid:
Py_XDECREF(naive);
Py_XDECREF(replace);
Py_XDECREF(args);
Py_XDECREF(kwargs);
return value;
}

/* Just make this compatible w/ the old API. */
int buffer_write_bytes(buffer_t buffer, const char* data, int size) {
if (pymongo_buffer_write(buffer, data, size)) {
Expand Down Expand Up @@ -482,6 +588,8 @@ static int _load_python_objects(PyObject* module) {
PyObject* empty_string = NULL;
PyObject* re_compile = NULL;
PyObject* compiled = NULL;
PyObject* min_datetime_ms = NULL;
PyObject* max_datetime_ms = NULL;
struct module_state *state = GETSTATE(module);
if (!state) {
return 1;
Expand Down Expand Up @@ -530,10 +638,21 @@ static int _load_python_objects(PyObject* module) {
_load_object(&state->UUID, "uuid", "UUID") ||
_load_object(&state->Mapping, "collections.abc", "Mapping") ||
_load_object(&state->DatetimeMS, "bson.datetime_ms", "DatetimeMS") ||
_load_object(&state->_min_datetime_ms, "bson.datetime_ms", "_min_datetime_ms") ||
_load_object(&state->_max_datetime_ms, "bson.datetime_ms", "_max_datetime_ms")) {
_load_object(&min_datetime_ms, "bson.datetime_ms", "_MIN_DATETIME_MS") ||
_load_object(&max_datetime_ms, "bson.datetime_ms", "_MAX_DATETIME_MS") ||
_load_object(&state->min_datetime, "bson.datetime_ms", "_MIN_DATETIME") ||
_load_object(&state->max_datetime, "bson.datetime_ms", "_MAX_DATETIME")) {
return 1;
}

state->min_millis = PyLong_AsLongLong(min_datetime_ms);
state->max_millis = PyLong_AsLongLong(max_datetime_ms);
Py_DECREF(min_datetime_ms);
Py_DECREF(max_datetime_ms);
if ((state->min_millis == -1 || state->max_millis == -1) && PyErr_Occurred()) {
return 1;
}

/* Reload our REType hack too. */
empty_string = PyBytes_FromString("");
if (empty_string == NULL) {
Expand Down Expand Up @@ -1241,21 +1360,29 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x02;
return write_unicode(buffer, value);
} else if (PyDateTime_Check(value)) {
long long millis;
long long millis = millis_from_datetime(value);
PyObject* utcoffset = PyObject_CallMethodObjArgs(value, state->_utcoffset_str , NULL);
if (utcoffset == NULL)
return 0;
if (utcoffset != Py_None) {
PyObject* result = PyNumber_Subtract(value, utcoffset);
Py_DECREF(utcoffset);
if (!result) {
if (!PyDelta_Check(utcoffset)) {
PyObject* BSONError = _error("BSONError");
if (BSONError) {
PyErr_SetString(BSONError,
"datetime.utcoffset() did not return a datetime.timedelta");
Py_DECREF(BSONError);
}
Py_DECREF(utcoffset);
return 0;
}
millis = millis_from_datetime(result);
Py_DECREF(result);
} else {
millis = millis_from_datetime(value);
PyDateTime_DELTA_GET_DAYS(utcoffset);
PyDateTime_DELTA_GET_SECONDS(utcoffset);
PyDateTime_DELTA_GET_MICROSECONDS(utcoffset);
millis -= (PyDateTime_DELTA_GET_DAYS(utcoffset) * 86400 +
PyDateTime_DELTA_GET_SECONDS(utcoffset)) * 1000 +
(PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000);
}
Py_DECREF(utcoffset);
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x09;
return buffer_write_int64(buffer, (int64_t)millis);
} else if (PyObject_TypeCheck(value, state->REType)) {
Expand Down Expand Up @@ -2043,11 +2170,6 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
case 9:
{
PyObject* naive;
PyObject* replace;
PyObject* args;
PyObject* kwargs;
PyObject* astimezone;
int64_t millis;
if (max < 8) {
goto invalid;
Expand All @@ -2056,120 +2178,7 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
millis = (int64_t)BSON_UINT64_FROM_LE(millis);
*position += 8;

if (options->datetime_conversion == DATETIME_MS){
value = datetime_ms_from_millis(self, millis);
break;
}

int dt_clamp = options->datetime_conversion == DATETIME_CLAMP;
int dt_auto = options->datetime_conversion == DATETIME_AUTO;


if (dt_clamp || dt_auto){
PyObject *min_millis_fn_res;
PyObject *max_millis_fn_res;
int64_t min_millis;
int64_t max_millis;

if (options->tz_aware){
PyObject* tzinfo = options->tzinfo;
if (tzinfo == Py_None) {
// Default to UTC.
tzinfo = state->UTC;
}
min_millis_fn_res = PyObject_CallFunctionObjArgs(state->_min_datetime_ms, tzinfo, NULL);
max_millis_fn_res = PyObject_CallFunctionObjArgs(state->_max_datetime_ms, tzinfo, NULL);
} else {
min_millis_fn_res = PyObject_CallObject(state->_min_datetime_ms, NULL);
max_millis_fn_res = PyObject_CallObject(state->_max_datetime_ms, NULL);
}

if (!min_millis_fn_res || !max_millis_fn_res){
Py_XDECREF(min_millis_fn_res);
Py_XDECREF(max_millis_fn_res);
goto invalid;
}

min_millis = PyLong_AsLongLong(min_millis_fn_res);
max_millis = PyLong_AsLongLong(max_millis_fn_res);

if ((min_millis == -1 || max_millis == -1) && PyErr_Occurred())
{
// min/max_millis check
goto invalid;
}

if (dt_clamp) {
if (millis < min_millis) {
millis = min_millis;
} else if (millis > max_millis) {
millis = max_millis;
}
// Continues from here to return a datetime.
} else { // dt_auto
if (millis < min_millis || millis > max_millis){
value = datetime_ms_from_millis(self, millis);
break; // Out-of-range so done.
}
}
}

naive = datetime_from_millis(millis);
if (!options->tz_aware) { /* In the naive case, we're done here. */
value = naive;
break;
}

if (!naive) {
goto invalid;
}
replace = PyObject_GetAttr(naive, state->_replace_str);
Py_DECREF(naive);
if (!replace) {
goto invalid;
}
args = PyTuple_New(0);
if (!args) {
Py_DECREF(replace);
goto invalid;
}
kwargs = PyDict_New();
if (!kwargs) {
Py_DECREF(replace);
Py_DECREF(args);
goto invalid;
}
if (PyDict_SetItem(kwargs, state->_tzinfo_str, state->UTC) == -1) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}
value = PyObject_Call(replace, args, kwargs);
if (!value) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}

/* convert to local time */
if (options->tzinfo != Py_None) {
astimezone = PyObject_GetAttr(value, state->_astimezone_str);
Py_DECREF(value);
if (!astimezone) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}
value = PyObject_CallFunctionObjArgs(astimezone, options->tzinfo, NULL);
Py_DECREF(astimezone);
}

Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
value = decode_datetime(self, millis, options);
break;
}
case 11:
Expand Down Expand Up @@ -3041,6 +3050,8 @@ static int _cbson_traverse(PyObject *m, visitproc visit, void *arg) {
Py_VISIT(state->_from_uuid_str);
Py_VISIT(state->_as_uuid_str);
Py_VISIT(state->_from_bid_str);
Py_VISIT(state->min_datetime);
Py_VISIT(state->max_datetime);
return 0;
}

Expand Down Expand Up @@ -3085,6 +3096,8 @@ static int _cbson_clear(PyObject *m) {
Py_CLEAR(state->_from_uuid_str);
Py_CLEAR(state->_as_uuid_str);
Py_CLEAR(state->_from_bid_str);
Py_CLEAR(state->min_datetime);
Py_CLEAR(state->max_datetime);
return 0;
}

Expand Down
Loading