diff --git a/CHANGELOG.md b/CHANGELOG.md index c238024..adda7f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +### 1.4.0 + +- feat: #119 Allow query results returned in columns and deserialized to `numpy` objects + ### 1.3.2 - feat(aggragation-function): add anyLast function. diff --git a/README.md b/README.md index aa93e87..679cce3 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ Read [Documentation](https://github.com/jayvynl/django-clickhouse-backend/blob/m - Support most clickhouse data types. - Support [SETTINGS in SELECT Query](https://clickhouse.com/docs/en/sql-reference/statements/select/#settings-in-select-query). - Support [PREWHERE clause](https://clickhouse.com/docs/en/sql-reference/statements/select/prewhere). +- Support query results returned in columns and [deserialized to `numpy` objects](https://clickhouse-driver.readthedocs.io/en/latest/features.html#numpy-pandas-support). **Notes:** @@ -381,6 +382,60 @@ and [distributed table engine](https://clickhouse.com/docs/en/engines/table-engi The following example assumes that a cluster defined by [docker compose in this repository](https://github.com/jayvynl/django-clickhouse-backend/blob/main/compose.yaml) is used. This cluster name is `cluster`, it has 2 shards, every shard has 2 replica. +Query results returned as columns and/or deserialized into `numpy` objects +--- + +`clickhouse-driver` allows results to be returned as columns and/or deserialized into +`numpy` objects. This backend supports both options by using the context manager, +`Cursor.set_query_execution_args()`. + +```python +import numpy as np +from django.db import connection + +sql = """ + SELECT toDateTime32('2022-01-01 01:00:05', 'UTC'), number, number*2.5 + FROM system.numbers + LIMIT 3 +""" +with connection.cursor() as cursorWrapper: + with cursorWrapper.cursor.set_query_execution_args( + columnar=True, use_numpy=True + ) as cursor: + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchall(), + [ + np.array( + [ + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + ], + dtype="datetime64[s]", + ), + np.array([0, 1, 2], dtype=np.uint64), + np.array([0, 2.5, 5.0], dtype=np.float64), + ], + ) + + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchmany(2), + [ + np.array( + [ + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + ], + dtype="datetime64[s]", + ), + np.array([0, 1, 2], dtype=np.uint64), + ], + ) +``` + ### Configuration ```python diff --git a/clickhouse_backend/driver/connection.py b/clickhouse_backend/driver/connection.py index 766b1cf..efc10af 100644 --- a/clickhouse_backend/driver/connection.py +++ b/clickhouse_backend/driver/connection.py @@ -1,8 +1,11 @@ import re +import typing as T +from contextlib import contextmanager from clickhouse_driver import connection from clickhouse_driver.dbapi import connection as dbapi_connection from clickhouse_driver.dbapi import cursor, errors +from clickhouse_driver.result import IterQueryResult, ProgressQueryResult, QueryResult from django.conf import settings from .escape import escape_params @@ -70,6 +73,10 @@ def send_query(self, query, query_id=None, params=None): class Cursor(cursor.Cursor): + # Whether to return data in columnar format. For backwards-compatibility, + # let's default to None. + columnar = None + def close(self): """Push client back to connection pool""" if self.closed: @@ -81,12 +88,64 @@ def close(self): def closed(self): return self._state == self._states.CURSOR_CLOSED + @property + def use_numpy(self): + return self._client.client_settings["use_numpy"] + + @use_numpy.setter + def use_numpy(self, value): + self._client.client_settings["use_numpy"] = value + if value: + try: + from clickhouse_driver.numpy.result import ( + NumpyIterQueryResult, + NumpyProgressQueryResult, + NumpyQueryResult, + ) + + self._client.query_result_cls = NumpyQueryResult + self._client.iter_query_result_cls = NumpyIterQueryResult + self._client.progress_query_result_cls = NumpyProgressQueryResult + except ImportError as e: + raise RuntimeError("Extras for NumPy must be installed") from e + else: + self._client.query_result_cls = QueryResult + self._client.iter_query_result_cls = IterQueryResult + self._client.progress_query_result_cls = ProgressQueryResult + + @contextmanager + def set_query_execution_args( + self, columnar: T.Optional[bool] = None, use_numpy: T.Optional[bool] = None + ): + original_use_numpy = self.use_numpy + if use_numpy is not None: + self.use_numpy = use_numpy + + original_columnar = self.columnar + if columnar is not None: + self.columnar = columnar + + yield self + + self.use_numpy = original_use_numpy + self.columnar = original_columnar + def __del__(self): # If someone forgets calling close method, # then release connection when gc happens. if not self.closed: self.close() + def _prepare(self): + """Override clickhouse_driver.Cursor._prepare() to add columnar kwargs. + + See https://github.com/jayvynl/django-clickhouse-backend/issues/119 + """ + execute, execute_kwargs = super()._prepare() + if self.columnar is not None: + execute_kwargs["columnar"] = self.columnar + return execute, execute_kwargs + def execute(self, operation, parameters=None): """fix https://github.com/jayvynl/django-clickhouse-backend/issues/9""" if getattr( diff --git a/tests/backends/tests.py b/tests/backends/tests.py index 5a32a76..2bd399a 100644 --- a/tests/backends/tests.py +++ b/tests/backends/tests.py @@ -1,5 +1,6 @@ """Tests related to django.db.backends that haven't been organized.""" import datetime +import importlib import threading import unittest import warnings @@ -560,6 +561,208 @@ def test_timezone_none_use_tz_false(self): connection.init_connection_state() +def check_numpy(): + """Check if numpy is installed.""" + spec = importlib.util.find_spec("numpy") + return spec is not None + + +class ColumnarTestCase(TransactionTestCase): + available_apps = ["backends"] + databases = {"default", "s2r1"} + + def test_columnar_query(self): + sql = """ + SELECT number, number*2, number*3, number*4, number*5 + FROM system.numbers + LIMIT 10 + """ + with connections["s2r1"].cursor() as cursorWrapper: + with cursorWrapper.cursor.set_query_execution_args(columnar=True) as cursor: + cursor.execute(sql) + self.assertEqual( + cursor.fetchall(), + [ + (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + (0, 2, 4, 6, 8, 10, 12, 14, 16, 18), + (0, 3, 6, 9, 12, 15, 18, 21, 24, 27), + (0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + (0, 5, 10, 15, 20, 25, 30, 35, 40, 45), + ], + ) + + cursor.execute(sql) + self.assertEqual( + cursor.fetchmany(2), + [ + (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + (0, 2, 4, 6, 8, 10, 12, 14, 16, 18), + ], + ) + + actual_results = [ + r + for results in iter(lambda: cursor.fetchmany(2), []) + for r in results + ] + self.assertEqual( + actual_results, + [ + (0, 3, 6, 9, 12, 15, 18, 21, 24, 27), + (0, 4, 8, 12, 16, 20, 24, 28, 32, 36), + (0, 5, 10, 15, 20, 25, 30, 35, 40, 45), + ], + ) + + cursor.execute(sql) + self.assertEqual( + cursor.fetchone(), + (0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ) + + @unittest.skipUnless(check_numpy(), "numpy is not installed") + def test_use_numpy_query(self): + sql = """ + SELECT toDateTime32('2022-01-01 01:00:05', 'UTC'), number, number*2.5 + FROM system.numbers + LIMIT 3 + """ + import numpy as np + + with connections["s2r1"].cursor() as cursorWrapper: + with cursorWrapper.cursor.set_query_execution_args( + columnar=True, use_numpy=True + ) as cursor: + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchall(), + [ + np.array( + [ + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + ], + dtype="datetime64[s]", + ), + np.array([0, 1, 2], dtype=np.uint64), + np.array([0, 2.5, 5.0], dtype=np.float64), + ], + ) + + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchmany(2), + [ + np.array( + [ + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + ], + dtype="datetime64[s]", + ), + np.array([0, 1, 2], dtype=np.uint64), + ], + ) + + actual_results = [ + r + for results in iter(lambda: cursor.fetchmany(2), []) + for r in results + ] + np.testing.assert_equal( + actual_results, + [ + np.array([0, 2.5, 5], dtype=np.float64), + ], + ) + + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchone(), + np.array( + [ + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + np.datetime64("2022-01-01T01:00:05"), + ], + dtype="datetime64[s]", + ), + ) + + @unittest.skipUnless(check_numpy(), "numpy is not installed") + def test_use_numpy_but_not_columnar_format(self): + sql = """ + SELECT toDateTime32('2022-01-01 01:00:05', 'UTC'), number, number*2.5 + FROM system.numbers + LIMIT 3 + """ + import numpy as np + + with connections["s2r1"].cursor() as cursorWrapper: + with cursorWrapper.cursor.set_query_execution_args( + columnar=False, use_numpy=True + ) as cursor: + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchall(), + [ + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 0, 0.0], + dtype=object, + ), + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 1, 2.5], + dtype=object, + ), + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 2, 5.0], + dtype=object, + ), + ], + ) + + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchmany(2), + [ + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 0, 0.0], + dtype="object", + ), + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 1, 2.5], + dtype="object", + ), + ], + ) + + actual_results = [ + r + for results in iter(lambda: cursor.fetchmany(2), []) + for r in results + ] + np.testing.assert_equal( + actual_results, + [ + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 2, 5.0], + dtype="object", + ), + ], + ) + + cursor.execute(sql) + np.testing.assert_equal( + cursor.fetchone(), + np.array( + [datetime.datetime(2022, 1, 1, 1, 0, 5), 0, 0.0], + dtype="object", + ), + ) + + # These tests aren't conditional because it would require differentiating # between MySQL+InnoDB and MySQL+MYISAM (something we currently can't do). class FkConstraintsTests(TransactionTestCase): diff --git a/tox.ini b/tox.ini index 2b494cf..a7761ac 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,7 @@ deps = django5.1: Django>=5.1,<5.2 coverage commands = + pip install pandas # Use local clickhouse_backend package so that coverage works properly. pip install -e . coverage run tests/runtests.py --debug-sql {posargs}