Skip to content

Commit 4434620

Browse files
committed
chore: Use context manager
1 parent c25f507 commit 4434620

File tree

2 files changed

+41
-34
lines changed

2 files changed

+41
-34
lines changed

clickhouse_backend/driver/connection.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from contextlib import contextmanager
23

34
from clickhouse_driver import connection
45
from clickhouse_driver.dbapi import connection as dbapi_connection
@@ -74,7 +75,6 @@ class Cursor(cursor.Cursor):
7475
# Whether to return data in columnar format. For backwards-compatibility,
7576
# let's default to None.
7677
columnar = None
77-
_use_numpy = None
7878

7979
def close(self):
8080
"""Push client back to connection pool"""
@@ -89,13 +89,10 @@ def closed(self):
8989

9090
@property
9191
def use_numpy(self):
92-
if self._use_numpy is None:
93-
return self._client.client_settings["use_numpy"]
94-
return self._use_numpy
92+
return self._client.client_settings["use_numpy"]
9593

9694
@use_numpy.setter
9795
def use_numpy(self, value):
98-
self._use_numpy = value
9996
self._client.client_settings["use_numpy"] = value
10097
if value:
10198
try:
@@ -113,6 +110,19 @@ def use_numpy(self, value):
113110
self._client.query_result_cls = QueryResult
114111
self._client.iter_query_result_cls = IterQueryResult
115112
self._client.progress_query_result_cls = ProgressQueryResult
113+
114+
@contextmanager
115+
def set_query_args(self, columnar: bool, use_numpy: bool):
116+
original_use_numpy = self.use_numpy
117+
self.use_numpy = use_numpy
118+
original_columnar = self.columnar
119+
self.columnar = columnar
120+
121+
yield self
122+
123+
self.use_numpy = original_use_numpy
124+
self.columnar = original_columnar
125+
116126

117127
def __del__(self):
118128
# If someone forgets calling close method,

tests/backends/tests.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -635,46 +635,43 @@ def test_use_numpy_query(self):
635635
import numpy as np
636636

637637
with connections["s2r1"].cursor() as cursorWrapper:
638-
cursorWrapper.cursor.columnar = True
639-
cursorWrapper.cursor.use_numpy = True
640-
cursorWrapper.execute(sql)
641-
np.testing.assert_equal(
642-
cursorWrapper.fetchall(),
638+
with cursorWrapper.cursor.set_query_args(columnar=True, use_numpy=True) as cursor:
639+
cursor.execute(sql)
640+
np.testing.assert_equal(
641+
cursor.fetchall(),
643642
[
644643
np.array([np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05')], dtype='datetime64[s]'),
645644
np.array([0, 1, 2], dtype=np.uint64),
646645
np.array([0, 2.5, 5.0], dtype=np.float64)
647646
],
648647
)
649648

650-
cursorWrapper.execute(sql)
651-
np.testing.assert_equal(
652-
cursorWrapper.fetchmany(2),
653-
[
654-
np.array([np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05')], dtype='datetime64[s]'),
655-
np.array([0, 1, 2], dtype=np.uint64),
649+
cursor.execute(sql)
650+
np.testing.assert_equal(
651+
cursor.fetchmany(2),
652+
[
653+
np.array([np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05')], dtype='datetime64[s]'),
654+
np.array([0, 1, 2], dtype=np.uint64),
656655
],
657656
)
658657

659-
actual_results = [
660-
r
661-
for results in iter(lambda: cursorWrapper.fetchmany(2), [])
662-
for r in results
663-
]
664-
np.testing.assert_equal(
665-
actual_results,
666-
[
667-
np.array([0, 2.5, 5], dtype=np.float64),
668-
],
669-
)
658+
actual_results = [
659+
r
660+
for results in iter(lambda: cursor.fetchmany(2), [])
661+
for r in results
662+
]
663+
np.testing.assert_equal(
664+
actual_results,
665+
[
666+
np.array([0, 2.5, 5], dtype=np.float64),
667+
],
668+
)
670669

671-
cursorWrapper.execute(sql)
672-
np.testing.assert_equal(
673-
cursorWrapper.fetchone(),
674-
np.array([np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05')], dtype='datetime64[s]'),
675-
)
676-
cursorWrapper.cursor.columnar = False
677-
cursorWrapper.cursor.use_numpy = False
670+
cursor.execute(sql)
671+
np.testing.assert_equal(
672+
cursor.fetchone(),
673+
np.array([np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05'), np.datetime64('2022-01-01T01:00:05')], dtype='datetime64[s]'),
674+
)
678675

679676
# These tests aren't conditional because it would require differentiating
680677
# between MySQL+InnoDB and MySQL+MYISAM (something we currently can't do).

0 commit comments

Comments
 (0)