Skip to content

Commit 6a52112

Browse files
committed
Add support for timezones (fixes #12)
1 parent 7ea86a7 commit 6a52112

File tree

5 files changed

+79
-7
lines changed

5 files changed

+79
-7
lines changed

aiochsa/parser.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import namedtuple
12
import pkgutil
23
import simplejson as json
34
from typing import Iterable
@@ -21,6 +22,9 @@
2122
)
2223

2324

25+
EnumOption = namedtuple('EnumOption', ['label', 'value'])
26+
27+
2428
@v_args(inline=True)
2529
class TypeTransformer(Transformer):
2630

@@ -37,7 +41,17 @@ def aggregate_type(self, name, func, type_):
3741
return self._types[name](type_)
3842

3943
def simple_type(self, name, *params):
40-
return self._types[name]()
44+
return self._types[name](*params)
45+
46+
def enum_param(self, label, value):
47+
return EnumOption(label, value)
48+
49+
def STRING(self, value):
50+
assert value[0] == value[-1] == "'"
51+
return value[1:-1]
52+
53+
def INT(self, value):
54+
return int(value)
4155

4256

4357
def parse_type(types: TypeRegistry, type_str):

aiochsa/type.lark

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ start: _type
33
_type: composite_type | aggregate_type | simple_type
44
composite_type: COMPOSITE_NAME "(" _type ("," _type)* ")"
55
aggregate_type: AGGREGATE_NAME "(" SIMPLE_NAME "," _type ")"
6-
simple_type: SIMPLE_NAME [params]
7-
params: "(" PARAM ("," PARAM)* ")"
6+
simple_type: SIMPLE_NAME [_params]
7+
_params: "(" _param ("," _param)* ")"
8+
_param: enum_param | STRING | INT
9+
enum_param: STRING "=" INT
810

911
COMPOSITE_NAME: /Tuple|Array|Nullable|LowCardinality/
1012
AGGREGATE_NAME: /AggregateFunction|SimpleAggregateFunction/
1113
SIMPLE_NAME: /(?!Tuple|Array|Nullable|LowCardinality|AggregateFunction|SimpleAggregateFunction)\w+/
12-
PARAM: /('([^\\']|\\.)*'\s*=\s)?-?\d+/
14+
STRING: /'([^\\']|\\.)*'/
15+
INT: /-?\d+/
1316

1417
%ignore /[ \t\f\r\n]+/ // whitespace

aiochsa/types.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
)
77
from uuid import UUID
88

9+
try:
10+
import zoneinfo
11+
except ImportError:
12+
from backports import zoneinfo
13+
914

1015
NoneType = type(None)
1116
PyType = TypeVar('PyType')
@@ -47,6 +52,9 @@ def __repr__(self):
4752
class StrType(BaseType):
4853
py_type = str
4954

55+
def __init__(self, *params):
56+
pass
57+
5058
@classmethod
5159
def escape(cls, value: str, escape=None) -> str:
5260
value = value.replace('\\', '\\\\').replace("'", "\\'")
@@ -70,6 +78,9 @@ class FloatType(BaseType):
7078
class DecimalType(BaseType):
7179
py_type = Decimal
7280

81+
def __init__(self, *params):
82+
pass
83+
7384
@classmethod
7485
def escape(cls, value: PyType, escape: Callable) -> str:
7586
return f"'{value}'"
@@ -101,6 +112,14 @@ def from_json(self, value: str) -> Optional[date]:
101112
class DateTimeType(BaseType):
102113
py_type = datetime
103114

115+
__slots__ = ('_tzinfo',)
116+
117+
def __init__(self, tz_name=None):
118+
if tz_name is None:
119+
self._tzinfo = None
120+
else:
121+
self._tzinfo = zoneinfo.ZoneInfo(tz_name)
122+
104123
@classmethod
105124
def escape(cls, value: datetime, escape=None) -> str:
106125
value = value.replace(tzinfo=None, microsecond=0)
@@ -114,7 +133,10 @@ def to_json(cls, value: datetime, to_json: Callable) -> JsonType:
114133
def from_json(self, value: str) -> Optional[datetime]:
115134
if value == '0000-00-00 00:00:00':
116135
return None
117-
return datetime.fromisoformat(value)
136+
result = datetime.fromisoformat(value)
137+
if self._tzinfo is not None:
138+
result = result.replace(tzinfo=self._tzinfo)
139+
return result
118140

119141

120142
class DateTimeUTCType(DateTimeType):
@@ -144,7 +166,11 @@ def to_json(cls, value: datetime, to_json: Callable) -> JsonType:
144166
return value.isoformat()
145167

146168
def from_json(self, value: str) -> datetime:
147-
return datetime.fromisoformat(value).replace(tzinfo=timezone.utc)
169+
result = datetime.fromisoformat(value)
170+
if self._tzinfo is None:
171+
return result.replace(tzinfo=timezone.utc)
172+
else:
173+
return result.replace(tzinfo=self._tzinfo).astimezone(timezone.utc)
148174

149175

150176
class UUIDType(BaseType):

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ packages =
1212
aiochsa
1313
install_requires =
1414
aiohttp>=3.7.2,<4.0.0
15+
backports.zoneinfo;python_version<"3.9"
1516
clickhouse_sqlalchemy>=0.1.4
1617
lark-parser>=0.7.7
1718
simplejson>=3.16.0

tests/test_types.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from datetime import date, datetime, timezone
1+
from datetime import date, datetime, timedelta, timezone
22
from decimal import Decimal
33
import enum
44
from ipaddress import IPv4Address, IPv6Address
@@ -54,6 +54,7 @@ class CustomStr(str):
5454

5555
SaType = Union[sa.types.TypeEngine, Type[sa.types.TypeEngine]]
5656

57+
5758
def combine_typed_rapameters(spec_seq: Iterable[Tuple[SaType, Iterable]]):
5859
return list( # Wrap into list to make it reusable (iterator is one-off)
5960
itertools.chain(*[
@@ -184,6 +185,20 @@ async def test_zero_dates(clickhouse_version, conn, sa_type, value):
184185
assert result is None
185186

186187

188+
@pytest.mark.parametrize('tz_name,tz_offset', [
189+
('UTC', 0),
190+
('EST', -18_000),
191+
('Europe/Moscow', 10_800),
192+
])
193+
async def test_timezones(conn, tz_name, tz_offset):
194+
dt = datetime(2020, 1, 1)
195+
result = await conn.fetchval(
196+
sa.func.toTimeZone(sa.func.toDateTime(dt), tz_name).select()
197+
)
198+
assert result.utcoffset().total_seconds() == tz_offset
199+
assert result.astimezone(timezone.utc).replace(tzinfo=None) == dt
200+
201+
187202
@pytest.fixture
188203
async def conn_utc(dsn):
189204
types = TypeRegistry()
@@ -248,6 +263,19 @@ async def test_datetime_utc_insert_naive(conn_utc, table_for_type):
248263
)
249264

250265

266+
@pytest.mark.parametrize('tz_name,tz_offset', [
267+
('UTC', 0),
268+
('EST', -18_000),
269+
('Europe/Moscow', 10_800),
270+
])
271+
async def test_timezones_with_utc(conn_utc, tz_name, tz_offset):
272+
dt = datetime(2020, 1, 1, tzinfo=timezone.utc)
273+
result = await conn_utc.fetchval(
274+
sa.func.toTimeZone(sa.func.toDateTime(dt, 'UTC'), tz_name).select()
275+
)
276+
assert result == dt
277+
278+
251279
@pytest.mark.parametrize('value', [0, 4294967295])
252280
async def test_simple_aggregate_function(conn, recreate_table_for_type, value):
253281
table_name = await recreate_table_for_type(

0 commit comments

Comments
 (0)