Skip to content

Commit 3ee757b

Browse files
committed
Support for complex params
1 parent 9f57279 commit 3ee757b

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
11
from databricks.sqlalchemy.base import DatabricksDialect
2-
from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
2+
from databricks.sqlalchemy._types import (
3+
TINYINT,
4+
TIMESTAMP,
5+
TIMESTAMP_NTZ,
6+
DatabricksArray,
7+
DatabricksMap,
8+
)
39

4-
__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"]
10+
__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]

src/databricks/sqlalchemy/_types.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sqlalchemy
66
from sqlalchemy.engine.interfaces import Dialect
77
from sqlalchemy.ext.compiler import compiles
8+
from sqlalchemy.types import TypeDecorator, UserDefinedType
89

910
from databricks.sql.utils import ParamEscaper
1011

@@ -321,3 +322,77 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
321322
@compiles(TINYINT, "databricks")
322323
def compile_tinyint(type_, compiler, **kw):
323324
return "TINYINT"
325+
326+
327+
class DatabricksArray(UserDefinedType):
328+
"""
329+
A custom array type that can wrap any other SQLAlchemy type.
330+
331+
Examples:
332+
DatabricksArray(String) -> ARRAY<STRING>
333+
DatabricksArray(Integer) -> ARRAY<INT>
334+
DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
335+
"""
336+
337+
def __init__(self, item_type):
338+
self.item_type = item_type() if isinstance(item_type, type) else item_type
339+
340+
def get_col_spec(self, **kw):
341+
if isinstance(self.item_type, UserDefinedType):
342+
# If it's a UserDefinedType, call its get_col_spec directly
343+
inner_type = self.item_type.get_col_spec(**kw)
344+
elif isinstance(self.item_type, TypeDecorator):
345+
# If it's a TypeDecorator, we need to get its dialect implementation
346+
dialect = kw.get("type_expression", None)
347+
if dialect:
348+
dialect = dialect.dialect
349+
impl = self.item_type.load_dialect_impl(dialect)
350+
# Compile the implementation type
351+
inner_type = impl.compile(dialect=dialect)
352+
else:
353+
# Fallback if no dialect available
354+
inner_type = self.item_type.impl.__class__.__name__.upper()
355+
else:
356+
# For basic SQLAlchemy types, use class name
357+
inner_type = self.item_type.__class__.__name__.upper()
358+
359+
return f"ARRAY<{inner_type}>"
360+
361+
362+
class DatabricksMap(UserDefinedType):
363+
"""
364+
A custom map type that can wrap any other SQLAlchemy types for both key and value.
365+
366+
Examples:
367+
DatabricksMap(String, String) -> MAP<STRING,STRING>
368+
DatabricksMap(Integer, String) -> MAP<INT,STRING>
369+
DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
370+
"""
371+
372+
def __init__(self, key_type, value_type):
373+
self.key_type = key_type() if isinstance(key_type, type) else key_type
374+
self.value_type = value_type() if isinstance(value_type, type) else value_type
375+
376+
def get_col_spec(self, **kw):
377+
def process_type(type_obj):
378+
if isinstance(type_obj, UserDefinedType):
379+
# If it's a UserDefinedType, call its get_col_spec directly
380+
return type_obj.get_col_spec(**kw)
381+
elif isinstance(type_obj, TypeDecorator):
382+
# If it's a TypeDecorator, we need to get its dialect implementation
383+
dialect = kw.get("type_expression", None)
384+
if dialect:
385+
dialect = dialect.dialect
386+
impl = type_obj.load_dialect_impl(dialect)
387+
# Compile the implementation type
388+
return impl.compile(dialect=dialect)
389+
else:
390+
# Fallback if no dialect available
391+
return type_obj.impl.__class__.__name__.upper()
392+
else:
393+
# For basic SQLAlchemy types, use class name
394+
return type_obj.__class__.__name__.upper()
395+
396+
key_type = process_type(self.key_type)
397+
value_type = process_type(self.value_type)
398+
return f"MAP<{key_type},{value_type}>"

0 commit comments

Comments
 (0)