Skip to content

Commit 75e1dbd

Browse files
committed
Add comb()
Closes #224
1 parent 5ab1b11 commit 75e1dbd

File tree

4 files changed

+102
-1
lines changed

4 files changed

+102
-1
lines changed

gmp.c

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,52 @@ MAKE_MPZ_UI_FUN(fac)
18341834
MAKE_MPZ_UI_FUN(fac2)
18351835
MAKE_MPZ_UI_FUN(fib)
18361836

1837+
static PyObject *
1838+
gmp_comb(PyObject *self, PyObject *const *args, Py_ssize_t nargs)
1839+
{
1840+
if (nargs != 2) {
1841+
PyErr_SetString(PyExc_TypeError, "two arguments required");
1842+
return NULL;
1843+
}
1844+
1845+
MPZ_Object *x, *y, *res = MPZ_new(0);
1846+
1847+
if (!res) {
1848+
return NULL; /* LCOV_EXCL_LINE */
1849+
}
1850+
CHECK_OP_INT(x, args[0]);
1851+
CHECK_OP_INT(y, args[1]);
1852+
if (zz_isneg(&x->z) || zz_isneg(&y->z)) {
1853+
PyErr_SetString(PyExc_ValueError,
1854+
"comb() not defined for negative values");
1855+
goto err;
1856+
}
1857+
1858+
int64_t n, k;
1859+
1860+
if ((zz_to_i64(&x->z, &n) || n > ULONG_MAX)
1861+
|| (zz_to_i64(&y->z, &k) || k > ULONG_MAX))
1862+
{
1863+
PyErr_Format(PyExc_OverflowError,
1864+
"comb() arguments should not exceed %ld",
1865+
ULONG_MAX);
1866+
goto err;
1867+
}
1868+
Py_XDECREF(x);
1869+
Py_XDECREF(y);
1870+
if (zz_bin((uint64_t)n, (uint64_t)k, &res->z)) {
1871+
/* LCOV_EXCL_START */
1872+
PyErr_NoMemory();
1873+
goto err;
1874+
/* LCOV_EXCL_STOP */
1875+
}
1876+
return (PyObject *)res;
1877+
err:
1878+
end:
1879+
Py_DECREF(res);
1880+
return NULL;
1881+
}
1882+
18371883
static zz_rnd
18381884
get_round_mode(PyObject *rndstr)
18391885
{
@@ -2040,6 +2086,10 @@ static PyMethodDef gmp_functions[] = {
20402086
{"fib", gmp_fib, METH_O,
20412087
("fib($module, n, /)\n--\n\n"
20422088
"Return the n-th Fibonacci number.")},
2089+
{"comb", (PyCFunction)gmp_comb, METH_FASTCALL,
2090+
("comb($module, n, k, /)\n--\n\nNumber of ways to choose k"
2091+
" items from n items without repetition and order.\n\n"
2092+
"Also called the binomial coefficient.")},
20432093
{"_mpmath_normalize", (PyCFunction)gmp__mpmath_normalize, METH_FASTCALL,
20442094
NULL},
20452095
{"_mpmath_create", (PyCFunction)gmp__mpmath_create, METH_FASTCALL, NULL},
@@ -2131,7 +2181,8 @@ gmp_exec(PyObject *m)
21312181
const char *str = ("import numbers, importlib.metadata as imp\n"
21322182
"numbers.Integral.register(gmp.mpz)\n"
21332183
"gmp.fac = gmp.factorial\n"
2134-
"gmp.__all__ = ['factorial', 'gcd', 'isqrt', 'mpz']\n"
2184+
"gmp.__all__ = ['comb', 'factorial', 'gcd', 'isqrt',\n"
2185+
" 'mpz']\n"
21352186
"gmp.__version__ = imp.version('python-gmp')\n");
21362187

21372188
PyObject *res = PyRun_String(str, Py_file_input, ns, ns);

tests/test_functions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from gmp import (
66
_mpmath_create,
77
_mpmath_normalize,
8+
comb,
89
double_fac,
910
fac,
1011
factorial,
@@ -45,6 +46,18 @@ def test_factorials(x):
4546
assert fm(x) == r
4647

4748

49+
@given(integers(min_value=0, max_value=12345),
50+
integers(min_value=0, max_value=12345))
51+
def test_comb(x, y):
52+
mx = mpz(x)
53+
my = mpz(y)
54+
r = math.comb(x, y)
55+
assert comb(mx, my) == r
56+
assert comb(mx, y) == r
57+
assert comb(x, my) == r
58+
assert comb(x, y) == r
59+
60+
4861
@given(bigints(), bigints(), bigints())
4962
@example(1<<(67*2), 1<<65, 1)
5063
@example(123, 1<<70, 1)
@@ -139,6 +152,16 @@ def test_interfaces():
139152
fac(-1)
140153
with pytest.raises(OverflowError):
141154
fac(2**1000)
155+
with pytest.raises(TypeError):
156+
comb(123)
157+
with pytest.raises(ValueError, match="not defined for negative values"):
158+
comb(-1, 2)
159+
with pytest.raises(ValueError, match="not defined for negative values"):
160+
comb(2, -1)
161+
with pytest.raises(OverflowError):
162+
comb(2**1000, 1)
163+
with pytest.raises(OverflowError):
164+
comb(1, 2**1000)
142165
with pytest.raises(TypeError):
143166
_mpmath_create(1j)
144167
with pytest.raises(TypeError):

zz.c

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2131,6 +2131,31 @@ MK_ZZ_FUNC_UL(fac, fac_ui)
21312131
MK_ZZ_FUNC_UL(fac2, 2fac_ui)
21322132
MK_ZZ_FUNC_UL(fib, fib_ui)
21332133

2134+
zz_err
2135+
zz_bin(uint64_t n, uint64_t k, zz_t *v)
2136+
{
2137+
if (n > ULONG_MAX || k > ULONG_MAX) {
2138+
return ZZ_BUF;
2139+
}
2140+
if (TMP_OVERFLOW) {
2141+
return ZZ_MEM; /* LCOV_EXCL_LINE */
2142+
}
2143+
2144+
mpz_t z;
2145+
2146+
mpz_init(z);
2147+
mpz_bin_uiui(z, (unsigned long)n, (unsigned long)k);
2148+
if (zz_resize(z->_mp_size, v) == ZZ_MEM) {
2149+
/* LCOV_EXCL_START */
2150+
mpz_clear(z);
2151+
return ZZ_MEM;
2152+
/* LCOV_EXCL_STOP */
2153+
}
2154+
mpn_copyi(v->digits, z->_mp_d, z->_mp_size);
2155+
mpz_clear(z);
2156+
return ZZ_OK;
2157+
}
2158+
21342159
zz_err
21352160
_zz_mpmath_normalize(zz_bitcnt_t prec, zz_rnd rnd, bool *negative,
21362161
zz_t *man, zz_t *exp, zz_bitcnt_t *bc)

zz.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ zz_err zz_fac(uint64_t u, zz_t *v);
131131
zz_err zz_fac2(uint64_t u, zz_t *v);
132132
zz_err zz_fib(uint64_t u, zz_t *v);
133133

134+
zz_err zz_bin(uint64_t n, uint64_t k, zz_t *v);
135+
134136
zz_err _zz_mpmath_normalize(zz_bitcnt_t prec, zz_rnd rnd, bool *negative,
135137
zz_t *man, zz_t *exp, zz_bitcnt_t *bc);
136138

0 commit comments

Comments
 (0)