diff --git a/doc/aerospike.rst b/doc/aerospike.rst index d637138c7c..8b40a52bd3 100644 --- a/doc/aerospike.rst +++ b/doc/aerospike.rst @@ -329,6 +329,14 @@ Other digest = aerospike.calc_digest("test", "demo", 1 ) pp.pprint(digest) +.. py:function:: get_partition_id(digest) -> int + + Calculate the partition ID using a digest. + + :param bytes-like object digest: a record digest. It can be calculated using :py:meth:`aerospike.calc_digest`. + :return: the partition ID for the digest + :rtype: :class:`int` + .. _client_config: Client Configuration diff --git a/src/main/aerospike.c b/src/main/aerospike.c index 25815d35dc..9e9e73615b 100644 --- a/src/main/aerospike.c +++ b/src/main/aerospike.c @@ -87,7 +87,7 @@ static PyMethodDef aerospike_methods[] = { METH_VARARGS | METH_KEYWORDS, "Calculate the digest of a key"}, //Get partition ID for given digest - {"get_partition_id", (PyCFunction)Aerospike_Get_Partition_Id, METH_VARARGS, + {"get_partition_id", (PyCFunction)Aerospike_Get_Partition_Id, METH_O, "Get partition ID for given digest"}, {NULL}}; diff --git a/src/main/calc_digest.c b/src/main/calc_digest.c index b6e990e98d..304b47247e 100644 --- a/src/main/calc_digest.c +++ b/src/main/calc_digest.c @@ -131,20 +131,40 @@ PyObject *Aerospike_Calc_Digest(PyObject *self, PyObject *args, PyObject *kwds) return Aerospike_Calc_Digest_Invoke(py_ns, py_set, py_key); } -PyObject *Aerospike_Get_Partition_Id(PyObject *self, PyObject *args) +PyObject *Aerospike_Get_Partition_Id(PyObject *self, PyObject *arg) { - // Python Function Arguments - as_digest_value digest; + Py_buffer py_buffer; + PyObject *py_retval = NULL; - // Python Function Argument Parsing - if (PyArg_Parse(args, "(s)", &digest) == false) { - return NULL; + if (PyArg_Parse(arg, "y*", &py_buffer) == false) { + goto exit; } - uint32_t part_id = 0; + as_error err; + as_error_init(&err); + + if (py_buffer.len != 20) { + as_error_update(&err, AEROSPIKE_ERR_PARAM, + "Digest must be 20 bytes long"); + goto CLEANUP_AND_EXIT; + } - part_id = as_partition_getid(digest, 4096); + uint32_t part_id = as_partition_getid(py_buffer.buf, 4096); - // Invoke Operation - return PyLong_FromLong(part_id); + py_retval = PyLong_FromLong(part_id); + if (!py_retval) { + as_error_update(&err, AEROSPIKE_ERR_CLIENT, + "Failed to retrieve partition id"); + goto CLEANUP_AND_EXIT; + } + +CLEANUP_AND_EXIT: + PyBuffer_Release(&py_buffer); + + if (err.code != AEROSPIKE_OK) { + raise_exception(&err); + } + +exit: + return py_retval; } diff --git a/test/new_tests/test_get_partition_id.py b/test/new_tests/test_get_partition_id.py new file mode 100644 index 0000000000..37886dca44 --- /dev/null +++ b/test/new_tests/test_get_partition_id.py @@ -0,0 +1,25 @@ +from aerospike import exception as e +import aerospike +import pytest + +# This isn't a correctness test. It's only for code coverage purposes +# and to make sure the API is aligned with the documentation +class TestGetPartitionID: + def test_basic_usage(self): + digest = aerospike.calc_digest("test", "demo", 1) + part_id = aerospike.get_partition_id(digest) + assert type(part_id) == int + + @pytest.mark.parametrize( + "digest, expected_exception", + [ + # Digests must be exactly 20 bytes long + (bytearray([0] * 21), e.ParamError), + (bytearray([0] * 19), e.ParamError), + # Does not accept strings + ("1" * 20, TypeError) + ] + ) + def test_invalid_digest(self, digest, expected_exception): + with pytest.raises(expected_exception): + aerospike.get_partition_id(digest)