Skip to content
This repository was archived by the owner on May 5, 2022. It is now read-only.

Commit 8e7f8a6

Browse files
committed
feat: split function
1 parent e8cbcfa commit 8e7f8a6

File tree

2 files changed

+64
-13
lines changed

2 files changed

+64
-13
lines changed

sqlalchemy_trino/datatype.py

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
from typing import *
23

34
from sqlalchemy import util
45
from sqlalchemy.sql import sqltypes
@@ -53,24 +54,74 @@
5354
}
5455

5556

56-
def parse_sqltype(type_str: str, column: str) -> TypeEngine:
57-
type_str = type_str.lower()
58-
m = re.match(r'^([\w\s]+)(?:\(([\d,\s]*)\))?', type_str)
59-
if m is None:
57+
class MAP(TypeEngine):
    """SQL type describing a ``map(key, value)`` column.

    ``parse_sqltype`` builds instances as ``MAP(key_type, value_type)``, so the
    constructor has to accept the two element types. Both arguments default to
    ``None`` so that a bare ``MAP()`` keeps working.

    :param key_type: type of the map keys
    :param value_type: type of the map values
    """

    def __init__(self, key_type: Optional[TypeEngine] = None,
                 value_type: Optional[TypeEngine] = None):
        super().__init__()
        # Attribute names mirror the constructor argument names.
        self.key_type = key_type
        self.value_type = value_type
59+
60+
61+
class ROW(TypeEngine):
    """Placeholder type returned for ``row(...)`` columns; defines nothing of its own."""
63+
64+
65+
def split(string: str, delimiter: str = ',',
          quote: str = '"', escaped_quote: str = r'\"',
          open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]:
    """
    A split function that is aware of quotes and brackets/parentheses.

    The string is only split on delimiters that are outside of any quoted
    section and outside of any bracket pair. Everything inside a quoted
    section (delimiters *and* brackets) is treated as plain text. The pieces
    are yielded verbatim, i.e. surrounding whitespace and quotes are kept.

    :param string: string to split
    :param delimiter: string defining where to split, usually a comma or space;
        may be longer than one character. An empty delimiter never matches.
    :param quote: string, either a single or a double quote
    :param escaped_quote: string representing an escaped quote
    :param open_bracket: string, either [, {, < or (
    :param close_bracket: string, either ], }, > or )
    :return: iterator over the pieces; always yields at least one item
    """
    depth = 0          # current bracket nesting level
    in_quotes = False  # True while inside a quoted section
    start = 0          # index where the current piece begins
    resume = 0         # first index not consumed by a multi-character delimiter
    for pos, character in enumerate(string):
        if pos < resume:
            # Still inside the delimiter that was matched a moment ago.
            continue
        if in_quotes:
            # Inside quotes only an unescaped closing quote is significant.
            # endswith() with explicit bounds avoids negative slice indices.
            if character == quote and not string.endswith(escaped_quote, 0, pos + 1):
                in_quotes = False
        elif depth == 0 and delimiter and string.startswith(delimiter, pos):
            yield string[start:pos]
            start = resume = pos + len(delimiter)
        elif character == open_bracket:
            depth += 1
        elif character == close_bracket:
            depth -= 1
        elif character == quote:
            in_quotes = True
    yield string[start:]
96+
97+
98+
def parse_sqltype(type_str: str) -> TypeEngine:
    """
    Translate a type string into a SQLAlchemy type object.

    The string is stripped and lower-cased, then matched as ``name`` optionally
    followed by ``(options)``. ``array``, ``map`` and ``row`` are handled
    structurally (recursing into their element types); every other name is
    looked up in ``_type_map``.

    :param type_str: type string, e.g. ``varchar(10)`` or ``array(integer)``
    :return: the parsed type, or ``sqltypes.NULLTYPE`` (after a warning) when
        the string cannot be parsed or the type name is not recognized
    """
    type_str = type_str.strip().lower()
    match = re.match(r'^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?', type_str)
    if not match:
        util.warn(f"Could not parse type name '{type_str}'")
        return sqltypes.NULLTYPE
    type_name = match.group("type")
    type_opts = match.group("options")

    if type_name in ("array", "map", "row"):
        # The options group is None/empty when the parenthesised part is
        # missing; recursing or splitting on that would raise.
        if not type_opts:
            util.warn(f"Could not parse type name '{type_str}'")
            return sqltypes.NULLTYPE
        if type_name == "array":
            item_type = parse_sqltype(type_opts)
            return sqltypes.ARRAY(item_type)
        if type_name == "map":
            parts = list(split(type_opts))
            # A map needs exactly a key type and a value type.
            if len(parts) != 2:
                util.warn(f"Could not parse type name '{type_str}'")
                return sqltypes.NULLTYPE
            key_type = parse_sqltype(parts[0])
            value_type = parse_sqltype(parts[1])
            return MAP(key_type, value_type)
        return ROW()  # TODO

    if type_name not in _type_map:
        util.warn(f"Did not recognize type '{type_name}'")
        return sqltypes.NULLTYPE
    type_class = _type_map[type_name]
    if type_name in ('time', 'timestamp'):
        type_kwargs = dict(timezone=type_str.endswith("with time zone"))
        return type_class(**type_kwargs)  # TODO: handle time/timestamp(p) precision
    # Remaining options are positional integer arguments such as length or
    # precision/scale.
    type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else []
    return type_class(*type_args)

sqlalchemy_trino/dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def get_columns(self, connection: Connection,
9797
for row in rows:
9898
columns.append(dict(
9999
name=row.Column,
100-
type=datatype.parse_sqltype(row.Type, row.Column),
100+
type=datatype.parse_sqltype(row.Type),
101101
nullable=getattr(row, 'Null', True),
102102
default=None,
103103
))

0 commit comments

Comments
 (0)