Skip to content
283 changes: 158 additions & 125 deletions bson/_cbsonmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ struct module_state {
PyObject* Decimal128;
PyObject* Mapping;
PyObject* DatetimeMS;
PyObject* _min_datetime_ms;
PyObject* _max_datetime_ms;
PyObject* min_datetime;
PyObject* max_datetime;
PyObject* replace_args;
PyObject* replace_kwargs;
PyObject* _type_marker_str;
PyObject* _flags_str;
PyObject* _pattern_str;
Expand All @@ -80,6 +82,8 @@ struct module_state {
PyObject* _from_uuid_str;
PyObject* _as_uuid_str;
PyObject* _from_bid_str;
int64_t min_millis;
int64_t max_millis;
};

#define GETSTATE(m) ((struct module_state*)PyModule_GetState(m))
Expand Down Expand Up @@ -253,7 +257,7 @@ static PyObject* datetime_from_millis(long long millis) {
* 2. Multiply that by 1000: 253402300799000
* 3. Add in microseconds divided by 1000 253402300799999
*
* (Note: BSON doesn't support microsecond accuracy, hence the rounding.)
* (Note: BSON doesn't support microsecond accuracy, hence the truncation.)
*
* To decode we could do:
* 1. Get seconds: timestamp / 1000: 253402300799
Expand Down Expand Up @@ -376,6 +380,118 @@ static int millis_from_datetime_ms(PyObject* dt, long long* out){
return 1;
}

/*
 * Call tzinfo.utcoffset(dt) and convert the result to whole milliseconds.
 *
 * state:  module state; supplies the interned "utcoffset" method name.
 * tzinfo: object whose utcoffset() method is called.
 * dt:     datetime passed to utcoffset().
 * out:    receives the offset in milliseconds; set to 0 when utcoffset()
 *         returns None.
 *
 * Returns 1 on success. Returns 0 if the call fails or the result is neither
 * None nor a datetime.timedelta (BSONError is raised in the latter case when
 * the error class can be looked up).
 */
static int _decode_datetime_utcoffset_millis(struct module_state* state,
                                             PyObject* tzinfo,
                                             PyObject* dt,
                                             int64_t* out) {
    PyObject* utcoffset = PyObject_CallMethodObjArgs(tzinfo, state->_utcoffset_str, dt, NULL);
    if (utcoffset == NULL) {
        return 0;
    }
    *out = 0;
    if (utcoffset != Py_None) {
        if (!PyDelta_Check(utcoffset)) {
            PyObject* BSONError = _error("BSONError");
            if (BSONError) {
                PyErr_SetString(BSONError, "tzinfo.utcoffset() did not return a datetime.timedelta");
                Py_DECREF(BSONError);
            }
            Py_DECREF(utcoffset);
            return 0;
        }
        /* Widen to 64 bits before multiplying: the DELTA_GET_* macros yield
         * plain ints, and days * 86400 * 1000 can overflow an int for an
         * unusually large timedelta. */
        *out = ((int64_t)PyDateTime_DELTA_GET_DAYS(utcoffset) * 86400 +
                (int64_t)PyDateTime_DELTA_GET_SECONDS(utcoffset)) * 1000 +
               (int64_t)(PyDateTime_DELTA_GET_MICROSECONDS(utcoffset) / 1000);
    }
    Py_DECREF(utcoffset);
    return 1;
}

/*
 * Decode a BSON datetime (milliseconds since the epoch) into a Python object.
 *
 * self:    the extension module; its state supplies cached objects and limits.
 * millis:  the decoded 64-bit millisecond value.
 * options: codec options. The fields read here are datetime_conversion,
 *          tz_aware and tzinfo.
 *
 * Behaviour by options->datetime_conversion:
 *   DATETIME_MS    - return datetime_ms_from_millis(self, millis) directly.
 *   DATETIME_CLAMP - clamp millis into [min_millis, max_millis], then build
 *                    a datetime.
 *   DATETIME_AUTO  - if millis is outside [min_millis, max_millis], return
 *                    datetime_ms_from_millis(self, millis); otherwise build
 *                    a datetime.
 *   anything else  - build a datetime without range checks.
 *
 * When options->tz_aware is false the naive datetime is returned as is.
 * Otherwise the naive value gets tzinfo=UTC via the cached replace()
 * arguments and, if options->tzinfo is not None, is converted with
 * astimezone(options->tzinfo).
 *
 * Returns a new reference, or NULL on failure.
 */
static PyObject* decode_datetime(PyObject* self, long long millis, const codec_options_t* options){
    PyObject* naive = NULL;
    PyObject* replace = NULL;
    PyObject* value = NULL;
    struct module_state *state = GETSTATE(self);
    if (!state) {
        /* Mirror the NULL check done elsewhere in this file instead of
         * dereferencing a missing module state below. */
        return NULL;
    }
    if (options->datetime_conversion == DATETIME_MS){
        return datetime_ms_from_millis(self, millis);
    }

    int dt_clamp = options->datetime_conversion == DATETIME_CLAMP;
    int dt_auto = options->datetime_conversion == DATETIME_AUTO;

    if (dt_clamp || dt_auto){
        int64_t min_millis = state->min_millis;
        int64_t max_millis = state->max_millis;
        int64_t min_millis_offset = 0;
        int64_t max_millis_offset = 0;
        if (options->tz_aware && options->tzinfo && options->tzinfo != Py_None) {
            if (!_decode_datetime_utcoffset_millis(state, options->tzinfo,
                                                   state->min_datetime, &min_millis_offset)) {
                return NULL;
            }
            if (!_decode_datetime_utcoffset_millis(state, options->tzinfo,
                                                   state->max_datetime, &max_millis_offset)) {
                return NULL;
            }
        }
        /* Narrow the accepted range by the timezone offset at each end:
         * a negative offset raises the lower bound and a positive offset
         * lowers the upper bound.
         * NOTE(review): presumably this keeps the later astimezone()
         * conversion inside the representable datetime range - confirm. */
        if (min_millis_offset < 0) {
            min_millis -= min_millis_offset;
        }

        if (max_millis_offset > 0) {
            max_millis -= max_millis_offset;
        }

        if (dt_clamp) {
            if (millis < min_millis) {
                millis = min_millis;
            } else if (millis > max_millis) {
                millis = max_millis;
            }
            // Continues from here to return a datetime.
        } else { // dt_auto
            if (millis < min_millis || millis > max_millis){
                return datetime_ms_from_millis(self, millis);
            }
        }
    }

    naive = datetime_from_millis(millis);
    if (!naive) {
        goto done;
    }

    if (!options->tz_aware) { /* In the naive case, we're done here. */
        return naive;
    }
    replace = PyObject_GetAttr(naive, state->_replace_str);
    if (!replace) {
        goto done;
    }
    /* replace_args / replace_kwargs are built once at module load time and
     * hold the tzinfo=UTC keyword argument. */
    value = PyObject_Call(replace, state->replace_args, state->replace_kwargs);
    if (!value) {
        goto done;
    }

    /* convert to local time */
    if (options->tzinfo != Py_None) {
        PyObject* temp = PyObject_CallMethodObjArgs(value, state->_astimezone_str, options->tzinfo, NULL);
        Py_DECREF(value);
        value = temp; /* NULL here propagates the astimezone() failure. */
    }
done:
    /* Shared exit for success and failure: value is NULL on every error path. */
    Py_XDECREF(naive);
    Py_XDECREF(replace);
    return value;
}

/* Just make this compatible w/ the old API. */
int buffer_write_bytes(buffer_t buffer, const char* data, int size) {
if (pymongo_buffer_write(buffer, data, size)) {
Expand Down Expand Up @@ -482,6 +598,8 @@ static int _load_python_objects(PyObject* module) {
PyObject* empty_string = NULL;
PyObject* re_compile = NULL;
PyObject* compiled = NULL;
PyObject* min_datetime_ms = NULL;
PyObject* max_datetime_ms = NULL;
struct module_state *state = GETSTATE(module);
if (!state) {
return 1;
Expand Down Expand Up @@ -530,10 +648,34 @@ static int _load_python_objects(PyObject* module) {
_load_object(&state->UUID, "uuid", "UUID") ||
_load_object(&state->Mapping, "collections.abc", "Mapping") ||
_load_object(&state->DatetimeMS, "bson.datetime_ms", "DatetimeMS") ||
_load_object(&state->_min_datetime_ms, "bson.datetime_ms", "_min_datetime_ms") ||
_load_object(&state->_max_datetime_ms, "bson.datetime_ms", "_max_datetime_ms")) {
_load_object(&min_datetime_ms, "bson.datetime_ms", "_MIN_UTC_MS") ||
_load_object(&max_datetime_ms, "bson.datetime_ms", "_MAX_UTC_MS") ||
_load_object(&state->min_datetime, "bson.datetime_ms", "_MIN_UTC") ||
_load_object(&state->max_datetime, "bson.datetime_ms", "_MAX_UTC")) {
return 1;
}

state->min_millis = PyLong_AsLongLong(min_datetime_ms);
state->max_millis = PyLong_AsLongLong(max_datetime_ms);
Py_DECREF(min_datetime_ms);
Py_DECREF(max_datetime_ms);
if ((state->min_millis == -1 || state->max_millis == -1) && PyErr_Occurred()) {
return 1;
}

/* Speed up datetime.replace(tzinfo=utc) call */
state->replace_args = PyTuple_New(0);
if (!state->replace_args) {
return 1;
}
state->replace_kwargs = PyDict_New();
if (!state->replace_kwargs) {
return 1;
}
if (PyDict_SetItem(state->replace_kwargs, state->_tzinfo_str, state->UTC) == -1) {
return 1;
}

/* Reload our REType hack too. */
empty_string = PyBytes_FromString("");
if (empty_string == NULL) {
Expand Down Expand Up @@ -1247,15 +1389,16 @@ static int _write_element_to_buffer(PyObject* self, buffer_t buffer,
return 0;
if (utcoffset != Py_None) {
PyObject* result = PyNumber_Subtract(value, utcoffset);
Py_DECREF(utcoffset);
if (!result) {
Py_DECREF(utcoffset);
return 0;
}
millis = millis_from_datetime(result);
Py_DECREF(result);
} else {
millis = millis_from_datetime(value);
}
Py_DECREF(utcoffset);
*(pymongo_buffer_get_buffer(buffer) + type_byte) = 0x09;
return buffer_write_int64(buffer, (int64_t)millis);
} else if (PyObject_TypeCheck(value, state->REType)) {
Expand Down Expand Up @@ -2043,11 +2186,6 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
}
case 9:
{
PyObject* naive;
PyObject* replace;
PyObject* args;
PyObject* kwargs;
PyObject* astimezone;
int64_t millis;
if (max < 8) {
goto invalid;
Expand All @@ -2056,120 +2194,7 @@ static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer,
millis = (int64_t)BSON_UINT64_FROM_LE(millis);
*position += 8;

if (options->datetime_conversion == DATETIME_MS){
value = datetime_ms_from_millis(self, millis);
break;
}

int dt_clamp = options->datetime_conversion == DATETIME_CLAMP;
int dt_auto = options->datetime_conversion == DATETIME_AUTO;


if (dt_clamp || dt_auto){
PyObject *min_millis_fn_res;
PyObject *max_millis_fn_res;
int64_t min_millis;
int64_t max_millis;

if (options->tz_aware){
PyObject* tzinfo = options->tzinfo;
if (tzinfo == Py_None) {
// Default to UTC.
tzinfo = state->UTC;
}
min_millis_fn_res = PyObject_CallFunctionObjArgs(state->_min_datetime_ms, tzinfo, NULL);
max_millis_fn_res = PyObject_CallFunctionObjArgs(state->_max_datetime_ms, tzinfo, NULL);
} else {
min_millis_fn_res = PyObject_CallObject(state->_min_datetime_ms, NULL);
max_millis_fn_res = PyObject_CallObject(state->_max_datetime_ms, NULL);
}

if (!min_millis_fn_res || !max_millis_fn_res){
Py_XDECREF(min_millis_fn_res);
Py_XDECREF(max_millis_fn_res);
goto invalid;
}

min_millis = PyLong_AsLongLong(min_millis_fn_res);
max_millis = PyLong_AsLongLong(max_millis_fn_res);

if ((min_millis == -1 || max_millis == -1) && PyErr_Occurred())
{
// min/max_millis check
goto invalid;
}

if (dt_clamp) {
if (millis < min_millis) {
millis = min_millis;
} else if (millis > max_millis) {
millis = max_millis;
}
// Continues from here to return a datetime.
} else { // dt_auto
if (millis < min_millis || millis > max_millis){
value = datetime_ms_from_millis(self, millis);
break; // Out-of-range so done.
}
}
}

naive = datetime_from_millis(millis);
if (!options->tz_aware) { /* In the naive case, we're done here. */
value = naive;
break;
}

if (!naive) {
goto invalid;
}
replace = PyObject_GetAttr(naive, state->_replace_str);
Py_DECREF(naive);
if (!replace) {
goto invalid;
}
args = PyTuple_New(0);
if (!args) {
Py_DECREF(replace);
goto invalid;
}
kwargs = PyDict_New();
if (!kwargs) {
Py_DECREF(replace);
Py_DECREF(args);
goto invalid;
}
if (PyDict_SetItem(kwargs, state->_tzinfo_str, state->UTC) == -1) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}
value = PyObject_Call(replace, args, kwargs);
if (!value) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}

/* convert to local time */
if (options->tzinfo != Py_None) {
astimezone = PyObject_GetAttr(value, state->_astimezone_str);
Py_DECREF(value);
if (!astimezone) {
Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
goto invalid;
}
value = PyObject_CallFunctionObjArgs(astimezone, options->tzinfo, NULL);
Py_DECREF(astimezone);
}

Py_DECREF(replace);
Py_DECREF(args);
Py_DECREF(kwargs);
value = decode_datetime(self, millis, options);
break;
}
case 11:
Expand Down Expand Up @@ -3053,6 +3078,10 @@ static int _cbson_traverse(PyObject *m, visitproc visit, void *arg) {
Py_VISIT(state->_from_uuid_str);
Py_VISIT(state->_as_uuid_str);
Py_VISIT(state->_from_bid_str);
Py_VISIT(state->min_datetime);
Py_VISIT(state->max_datetime);
Py_VISIT(state->replace_args);
Py_VISIT(state->replace_kwargs);
return 0;
}

Expand Down Expand Up @@ -3097,6 +3126,10 @@ static int _cbson_clear(PyObject *m) {
Py_CLEAR(state->_from_uuid_str);
Py_CLEAR(state->_as_uuid_str);
Py_CLEAR(state->_from_bid_str);
Py_CLEAR(state->min_datetime);
Py_CLEAR(state->max_datetime);
Py_CLEAR(state->replace_args);
Py_CLEAR(state->replace_kwargs);
return 0;
}

Expand Down
Loading