|
1 | 1 | import re
|
| 2 | +from typing import * |
2 | 3 |
|
3 | 4 | from sqlalchemy import util
|
4 | 5 | from sqlalchemy.sql import sqltypes
|
|
53 | 54 | }
|
54 | 55 |
|
55 | 56 |
|
56 |
| -def parse_sqltype(type_str: str, column: str) -> TypeEngine: |
57 |
| - type_str = type_str.lower() |
58 |
| - m = re.match(r'^([\w\s]+)(?:\(([\d,\s]*)\))?', type_str) |
59 |
| - if m is None: |
| 57 | +class MAP(TypeEngine): |
| 58 | + pass |
| 59 | + |
| 60 | + |
| 61 | +class ROW(TypeEngine): |
| 62 | + pass |
| 63 | + |
| 64 | + |
| 65 | +def split(string: str, delimiter: str = ',', |
| 66 | + quote: str = '"', escaped_quote: str = r'\"', |
| 67 | + open_bracket: str = '(', close_bracket: str = ')') -> Iterator[str]: |
| 68 | + """ |
| 69 | + A split function that is aware of quotes and brackets/parentheses. |
| 70 | +
|
| 71 | + :param string: string to split |
| 72 | + :param delimiter: string defining where to split, usually a comma or space |
| 73 | + :param quote: string, either a single or a double quote |
| 74 | + :param escaped_quote: string representing an escaped quote |
| 75 | + :param open_bracket: string, either [, {, < or ( |
| 76 | + :param close_bracket: string, either ], }, > or ) |
| 77 | + """ |
| 78 | + parens = 0 |
| 79 | + quotes = False |
| 80 | + i = 0 |
| 81 | + for j, character in enumerate(string): |
| 82 | + complete = parens == 0 and not quotes |
| 83 | + if complete and character == delimiter: |
| 84 | + yield string[i:j] |
| 85 | + i = j + len(delimiter) |
| 86 | + elif character == open_bracket: |
| 87 | + parens += 1 |
| 88 | + elif character == close_bracket: |
| 89 | + parens -= 1 |
| 90 | + elif character == quote: |
| 91 | + if quotes and string[j - len(escaped_quote) + 1: j + 1] != escaped_quote: |
| 92 | + quotes = False |
| 93 | + elif not quotes: |
| 94 | + quotes = True |
| 95 | + yield string[i:] |
| 96 | + |
| 97 | + |
| 98 | +def parse_sqltype(type_str: str) -> TypeEngine: |
| 99 | + type_str = type_str.strip().lower() |
| 100 | + match = re.match(r'^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?', type_str) |
| 101 | + if not match: |
60 | 102 | util.warn(f"Could not parse type name '{type_str}'")
|
61 | 103 | return sqltypes.NULLTYPE
|
62 |
| - type_name, type_opts = m.groups() # type: str, str |
63 |
| - type_name = type_name.strip() |
| 104 | + type_name = match.group("type") |
| 105 | + type_opts = match.group("options") |
| 106 | + |
| 107 | + if type_name == "array": |
| 108 | + item_type = parse_sqltype(type_opts) |
| 109 | + return sqltypes.ARRAY(item_type) |
| 110 | + elif type_name == "map": |
| 111 | + key_type_str, value_type_str = split(type_opts) |
| 112 | + key_type = parse_sqltype(key_type_str) |
| 113 | + value_type = parse_sqltype(value_type_str) |
| 114 | + return MAP(key_type, value_type) |
| 115 | + elif type_name == "row": |
| 116 | + attr_types = split(type_opts) |
| 117 | + return ROW() # TODO |
| 118 | + |
64 | 119 | if type_name not in _type_map:
|
65 |
| - util.warn(f"Did not recognize type '{type_name}' of column '{column}'") |
| 120 | + util.warn(f"Did not recognize type '{type_name}'") |
66 | 121 | return sqltypes.NULLTYPE
|
67 | 122 | type_class = _type_map[type_name]
|
68 | 123 | type_args = [int(o.strip()) for o in type_opts.split(',')] if type_opts else []
|
69 | 124 | if type_name in ('time', 'timestamp'):
|
70 | 125 | type_kwargs = dict(timezone=type_str.endswith("with time zone"))
|
71 |
| - # TODO: handle time/timestamp(p) precision |
72 |
| - return type_class(**type_kwargs) |
73 |
| - if type_name in ('array', 'map', 'row'): |
74 |
| - # TODO |
75 |
| - return sqltypes.NULLTYPE |
| 126 | + return type_class(**type_kwargs) # TODO: handle time/timestamp(p) precision |
76 | 127 | return type_class(*type_args)
|
0 commit comments