diff --git a/gmp.c b/gmp.c index 91ff90e..5f27862 100644 --- a/gmp.c +++ b/gmp.c @@ -691,14 +691,36 @@ to_float(PyObject *self) static PyObject * richcompare(PyObject *self, PyObject *other, int op) { - MPZ_Object *u = (MPZ_Object *)self, *v = NULL; + zz_t *u = &((MPZ_Object *)self)->z; + zz_ord r; - assert(MPZ_Check(self)); - CHECK_OP(v, other); + if (MPZ_Check(other)) { + r = zz_cmp(u, &((MPZ_Object *)other)->z); + } + else if (PyLong_Check(other)) { + int32_t v; - zz_ord r = zz_cmp(&u->z, &v->z); + if (PyLong_AsInt32(other, &v) == 0) { + r = zz_cmp_i32(u, v); + } + else { + PyErr_Clear(); - Py_DECREF(v); + MPZ_Object *v = MPZ_from_int(other); + + if (!v) { + goto end; /* LCOV_EXCL_LINE */ + } + r = zz_cmp(u, &v->z); + Py_DECREF(v); + } + } + else if (Number_Check(other)) { + goto numbers; + } + else { + goto fallback; + } switch (op) { case Py_LT: return PyBool_FromLong(r == ZZ_LT); @@ -861,7 +883,129 @@ to_bool(PyObject *self) return res; \ } -BINOP(add, PyNumber_Add) +#define CHECK_OPv2(u, a, iu) \ + if (MPZ_Check(a)) { \ + u = (MPZ_Object *)a; \ + Py_INCREF(u); \ + } \ + else if (PyLong_Check(a)) { \ + iu = true; \ + } \ + else if (Number_Check(a)) { \ + goto numbers; \ + } \ + else { \ + goto fallback; \ + } + +#define BINOPv2(suff, slot) \ + static PyObject * \ + nb_##suff(PyObject *self, PyObject *other) \ + { \ + PyObject *res = NULL; \ + MPZ_Object *u = NULL, *v = NULL; \ + bool iu = false, iv = false; \ + \ + CHECK_OPv2(u, self, iu); \ + CHECK_OPv2(v, other, iv); \ + \ + res = (PyObject *)MPZ_new(0); \ + if (!res) { \ + goto end; \ + } \ + \ + zz_err ret = ZZ_OK; \ + \ + if (iu) { \ + int32_t x; \ + \ + if (PyLong_AsInt32(self, &x) == 0) { \ + ret = zz_i32_##suff(x, &v->z, \ + &((MPZ_Object *)res)->z); \ + goto done; \ + } \ + else { \ + PyErr_Clear(); \ + u = MPZ_from_int(self); \ + if (!u) { \ + goto end; \ + } \ + } \ + } \ + if (iv) { \ + int32_t x; \ + \ + if (PyLong_AsInt32(other, &x) == 0) { \ + ret = zz_##suff##_i32(&u->z, x, \ + &((MPZ_Object *)res)->z); \ + goto done; \ + } \ + else { \ + PyErr_Clear(); \ + v = MPZ_from_int(other); \ + if (!v) { \ + goto end; \ + } \ + } \ + } \ + ret = zz_##suff(&u->z, &v->z, &((MPZ_Object *)res)->z); \ +done: \ + if (ret == ZZ_OK) { \ + goto end; \ + } \ + if (ret == ZZ_VAL) { \ + Py_CLEAR(res); \ + PyErr_SetString(PyExc_ZeroDivisionError, \ + "division by zero"); \ + } \ + else { \ + Py_CLEAR(res); \ + PyErr_NoMemory(); \ + } \ + end: \ + Py_XDECREF(u); \ + Py_XDECREF(v); \ + return res; \ + fallback: \ + Py_XDECREF(u); \ + Py_XDECREF(v); \ + Py_RETURN_NOTIMPLEMENTED; \ + numbers: \ + Py_XDECREF(u); \ + Py_XDECREF(v); \ + \ + PyObject *uf, *vf; \ + \ + if (Number_Check(self)) { \ + uf = self; \ + Py_INCREF(uf); \ + } \ + else { \ + uf = to_float(self); \ + if (!uf) { \ + return NULL; \ + } \ + } \ + if (Number_Check(other)) { \ + vf = other; \ + Py_INCREF(vf); \ + } \ + else { \ + vf = to_float(other); \ + if (!vf) { \ + Py_DECREF(uf); \ + return NULL; \ + } \ + } \ + res = slot(uf, vf); \ + Py_DECREF(uf); \ + Py_DECREF(vf); \ + return res; \ + } + +#define zz_i32_add(x, y, r) zz_add_i32((y), (x), (r)) + +BINOPv2(add, PyNumber_Add) BINOP(sub, PyNumber_Subtract) BINOP(mul, PyNumber_Multiply)