 import sqlalchemy
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.types import TypeDecorator, UserDefinedType

 from databricks.sql.utils import ParamEscaper

@@ -321,3 +322,77 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
 @compiles(TINYINT, "databricks")
 def compile_tinyint(type_, compiler, **kw):
     return "TINYINT"
+
+
+class DatabricksArray(UserDefinedType):
+    """
+    A custom array type that can wrap any other SQLAlchemy type.
+
+    Examples:
+        DatabricksArray(String) -> ARRAY<STRING>
+        DatabricksArray(Integer) -> ARRAY<INT>
+        DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
+    """
+
+    def __init__(self, item_type):
+        self.item_type = item_type() if isinstance(item_type, type) else item_type
+
+    def get_col_spec(self, **kw):
+        if isinstance(self.item_type, UserDefinedType):
+            # If it's a UserDefinedType, call its get_col_spec directly
+            inner_type = self.item_type.get_col_spec(**kw)
+        elif isinstance(self.item_type, TypeDecorator):
+            # If it's a TypeDecorator, we need to get its dialect implementation
+            dialect = kw.get("type_expression", None)
+            if dialect:
+                dialect = dialect.dialect
+                impl = self.item_type.load_dialect_impl(dialect)
+                # Compile the implementation type
+                inner_type = impl.compile(dialect=dialect)
+            else:
+                # Fallback if no dialect available
+                inner_type = self.item_type.impl.__class__.__name__.upper()
+        else:
+            # For basic SQLAlchemy types, use class name
+            inner_type = self.item_type.__class__.__name__.upper()
+
+        return f"ARRAY<{inner_type}>"
+
+
+class DatabricksMap(UserDefinedType):
+    """
+    A custom map type that can wrap any other SQLAlchemy types for both key and value.
+
+    Examples:
+        DatabricksMap(String, String) -> MAP<STRING,STRING>
+        DatabricksMap(Integer, String) -> MAP<INT,STRING>
+        DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
+    """
+
+    def __init__(self, key_type, value_type):
+        self.key_type = key_type() if isinstance(key_type, type) else key_type
+        self.value_type = value_type() if isinstance(value_type, type) else value_type
+
+    def get_col_spec(self, **kw):
+        def process_type(type_obj):
+            if isinstance(type_obj, UserDefinedType):
+                # If it's a UserDefinedType, call its get_col_spec directly
+                return type_obj.get_col_spec(**kw)
+            elif isinstance(type_obj, TypeDecorator):
+                # If it's a TypeDecorator, we need to get its dialect implementation
+                dialect = kw.get("type_expression", None)
+                if dialect:
+                    dialect = dialect.dialect
+                    impl = type_obj.load_dialect_impl(dialect)
+                    # Compile the implementation type
+                    return impl.compile(dialect=dialect)
+                else:
+                    # Fallback if no dialect available
+                    return type_obj.impl.__class__.__name__.upper()
+            else:
+                # For basic SQLAlchemy types, use class name
+                return type_obj.__class__.__name__.upper()
+
+        key_type = process_type(self.key_type)
+        value_type = process_type(self.value_type)
+        return f"MAP<{key_type},{value_type}>"
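For reference, a minimal usage sketch of the two new types, not taken from the diff above: the databricks.sqlalchemy import path, the table name, and the column names are assumptions for illustration.

from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.schema import CreateTable

# Assumed import path for the types added in this diff.
from databricks.sqlalchemy import DatabricksArray, DatabricksMap

metadata = MetaData()

# Hypothetical table whose columns use the new wrapper types.
example = Table(
    "example",
    metadata,
    Column("tags", DatabricksArray(String)),
    Column("attrs", DatabricksMap(String, DatabricksArray(String))),
)

# get_col_spec() resolves plain SQLAlchemy types to their class names, so the
# rendered DDL should contain "tags ARRAY<STRING>" and
# "attrs MAP<STRING,ARRAY<STRING>>".
print(CreateTable(example))

Nesting works because DatabricksArray and DatabricksMap are themselves UserDefinedType subclasses, so process_type simply delegates to their get_col_spec when they appear as an element, key, or value type.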