
Commit ee8cc29

Added better parsing

1 parent 7b4cfe0 · commit ee8cc29

File tree

3 files changed: +103 -51 lines changed

src/databricks/sqlalchemy/_types.py
tests/test_local/test_ddl.py
tests/test_local/test_parsing.py

src/databricks/sqlalchemy/_types.py

Lines changed: 10 additions & 41 deletions
@@ -337,26 +337,12 @@ class DatabricksArray(UserDefinedType):
     def __init__(self, item_type):
         self.item_type = item_type() if isinstance(item_type, type) else item_type
 
-    def get_col_spec(self, **kw):
-        if isinstance(self.item_type, UserDefinedType):
-            # If it's a UserDefinedType, call its get_col_spec directly
-            inner_type = self.item_type.get_col_spec(**kw)
-        elif isinstance(self.item_type, TypeDecorator):
-            # If it's a TypeDecorator, we need to get its dialect implementation
-            dialect = kw.get("type_expression", None)
-            if dialect:
-                dialect = dialect.dialect
-                impl = self.item_type.load_dialect_impl(dialect)
-                # Compile the implementation type
-                inner_type = impl.compile(dialect=dialect)
-            else:
-                # Fallback if no dialect available
-                inner_type = self.item_type.impl.__class__.__name__.upper()
-        else:
-            # For basic SQLAlchemy types, use class name
-            inner_type = self.item_type.__class__.__name__.upper()
 
-        return f"ARRAY<{inner_type}>"
+@compiles(DatabricksArray, "databricks")
+def compile_databricks_array(type_, compiler, **kw):
+    inner = compiler.process(type_.item_type, **kw)
+
+    return f"ARRAY<{inner}>"
 
 
 class DatabricksMap(UserDefinedType):
@@ -373,26 +359,9 @@ def __init__(self, key_type, value_type):
         self.key_type = key_type() if isinstance(key_type, type) else key_type
         self.value_type = value_type() if isinstance(value_type, type) else value_type
 
-    def get_col_spec(self, **kw):
-        def process_type(type_obj):
-            if isinstance(type_obj, UserDefinedType):
-                # If it's a UserDefinedType, call its get_col_spec directly
-                return type_obj.get_col_spec(**kw)
-            elif isinstance(type_obj, TypeDecorator):
-                # If it's a TypeDecorator, we need to get its dialect implementation
-                dialect = kw.get("type_expression", None)
-                if dialect:
-                    dialect = dialect.dialect
-                    impl = type_obj.load_dialect_impl(dialect)
-                    # Compile the implementation type
-                    return impl.compile(dialect=dialect)
-                else:
-                    # Fallback if no dialect available
-                    return type_obj.impl.__class__.__name__.upper()
-            else:
-                # For basic SQLAlchemy types, use class name
-                return type_obj.__class__.__name__.upper()
 
-        key_type = process_type(self.key_type)
-        value_type = process_type(self.value_type)
-        return f"MAP<{key_type},{value_type}>"
+@compiles(DatabricksMap, "databricks")
+def compile_databricks_map(type_, compiler, **kw):
+    key_type = compiler.process(type_.key_type, **kw)
+    value_type = compiler.process(type_.value_type, **kw)
+    return f"MAP<{key_type},{value_type}>"

tests/test_local/test_ddl.py

Lines changed: 9 additions & 10 deletions
@@ -1,13 +1,14 @@
 import pytest
-from sqlalchemy import Column, MetaData, String, Table, Numeric, create_engine
+from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine
 from sqlalchemy.schema import (
     CreateTable,
     DropColumnComment,
     DropTableComment,
     SetColumnComment,
     SetTableComment,
 )
-from databricks.sqlalchemy import DatabricksArray,DatabricksMap
+from databricks.sqlalchemy import DatabricksArray, DatabricksMap
+
 
 class DDLTestBase:
     engine = create_engine(
@@ -95,21 +96,19 @@ def test_alter_table_drop_comment(self, table_with_comment):
         output = self.compile(stmt)
         assert output == "COMMENT ON TABLE martin IS NULL"
 
+
 class TestTableComplexTypeDDL(DDLTestBase):
-    @pytest.fixture
+    @pytest.fixture(scope="class")
     def metadata(self) -> MetaData:
         metadata = MetaData()
-        col1 = Column("array_array_string",DatabricksArray(DatabricksArray(String)))
-        col2 = Column("map_string_string",DatabricksMap(String,String))
-        col3 = Column("array_array_decimal",DatabricksArray(DatabricksArray(Numeric(10,2))))
-        table = Table("complex_type", metadata, col1,col2,col3)
+        col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
+        col2 = Column("map_string_string", DatabricksMap(String, String))
+        table = Table("complex_type", metadata, col1, col2)
         return metadata
-
+
     def test_create_table_with_complex_type(self, metadata):
         stmt = CreateTable(metadata.tables["complex_type"])
         output = self.compile(stmt)
 
-        print(output)
         assert "array_array_string ARRAY<ARRAY<STRING>>" in output
         assert "map_string_string MAP<STRING,STRING>" in output
-        assert "array_array_decimal ARRAY<ARRAY<DECIMAL(10,2)>>" in output

tests/test_local/test_parsing.py

Lines changed: 84 additions & 0 deletions
@@ -9,7 +9,29 @@
     get_comment_from_dte_output,
     DatabricksSqlAlchemyParseException,
 )
+from sqlalchemy import (
+    BigInteger,
+    Boolean,
+    Column,
+    Date,
+    DateTime,
+    Integer,
+    Numeric,
+    String,
+    Time,
+    Uuid,
+)
+
+from databricks.sqlalchemy import (
+    DatabricksArray,
+    TIMESTAMP,
+    TINYINT,
+    DatabricksMap,
+    TIMESTAMP_NTZ,
+)
+from databricks.sqlalchemy import DatabricksDialect
 
+dialect = DatabricksDialect()
 
 # These are outputs from DESCRIBE TABLE EXTENDED
 @pytest.mark.parametrize(
@@ -158,3 +180,65 @@ def test_filter_dict_by_value(match, output):
 
 def test_get_comment_from_dte_output():
     assert get_comment_from_dte_output(FMT_SAMPLE_DT_OUTPUT) == "some comment"
+
+
+def get_databricks_non_compound_types():
+    return [
+        Integer,
+        String,
+        Boolean,
+        Date,
+        DateTime,
+        Time,
+        Uuid,
+        Numeric,
+        TINYINT,
+        TIMESTAMP,
+        TIMESTAMP_NTZ,
+    ]
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_array_parsing(internal_type):
+    array_type = DatabricksArray(internal_type())
+
+    actual_parsed = array_type.compile(dialect=dialect)
+    expected_parsed = "ARRAY<{}>".format(internal_type().compile(dialect=dialect))
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type_1", get_databricks_non_compound_types())
+@pytest.mark.parametrize("internal_type_2", get_databricks_non_compound_types())
+def test_map_parsing(internal_type_1, internal_type_2):
+    map_type = DatabricksMap(internal_type_1(), internal_type_2())
+
+    actual_parsed = map_type.compile(dialect=dialect)
+    expected_parsed = "MAP<{},{}>".format(
+        internal_type_1().compile(dialect=dialect),
+        internal_type_2().compile(dialect=dialect),
+    )
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_multilevel_array_type_parsing(internal_type):
+    array_type = DatabricksArray(DatabricksArray(DatabricksArray(internal_type())))
+
+    actual_parsed = array_type.compile(dialect=dialect)
+    expected_parsed = "ARRAY<ARRAY<ARRAY<{}>>>".format(
+        internal_type().compile(dialect=dialect)
+    )
+    assert actual_parsed == expected_parsed
+
+
+@pytest.mark.parametrize("internal_type", get_databricks_non_compound_types())
+def test_multilevel_map_type_parsing(internal_type):
+    map_type = DatabricksMap(
+        String, DatabricksMap(String, DatabricksMap(String, internal_type()))
+    )
+
+    actual_parsed = map_type.compile(dialect=dialect)
+    expected_parsed = "MAP<STRING,MAP<STRING,MAP<STRING,{}>>>".format(
+        internal_type().compile(dialect=dialect)
+    )
+    assert actual_parsed == expected_parsed
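A note on coverage: with the 11 non-compound types above, test_map_parsing expands to 121 key/value combinations. One quick interactive check of a TypeDecorator-backed case (a sketch, not part of the commit; it assumes TINYINT renders as TINYINT under this dialect):

# Sketch only: TypeDecorator-backed types now flow through compiler.process
# rather than the removed load_dialect_impl fallback in get_col_spec.
from databricks.sqlalchemy import DatabricksArray, DatabricksDialect, TINYINT

print(DatabricksArray(TINYINT()).compile(dialect=DatabricksDialect()))
# -> ARRAY<TINYINT> (assumed rendering)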
