10 changes: 8 additions & 2 deletions src/databricks/sqlalchemy/__init__.py
@@ -1,4 +1,10 @@
 from databricks.sqlalchemy.base import DatabricksDialect
-from databricks.sqlalchemy._types import TINYINT, TIMESTAMP, TIMESTAMP_NTZ
+from databricks.sqlalchemy._types import (
+    TINYINT,
+    TIMESTAMP,
+    TIMESTAMP_NTZ,
+    DatabricksArray,
+    DatabricksMap,
+)
 
-__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ"]
+__all__ = ["TINYINT", "TIMESTAMP", "TIMESTAMP_NTZ", "DatabricksArray", "DatabricksMap"]
76 changes: 76 additions & 0 deletions src/databricks/sqlalchemy/_types.py
@@ -5,6 +5,7 @@
 import sqlalchemy
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.types import TypeDecorator, UserDefinedType
 
 from databricks.sql.utils import ParamEscaper

@@ -26,6 +27,11 @@ def process_literal_param_hack(value: Any):
     return value
 
 
+def identity_processor(value):
+    """Return the value unchanged when no other bind processor is provided."""
+    return value
+
+
 @compiles(sqlalchemy.types.Enum, "databricks")
 @compiles(sqlalchemy.types.String, "databricks")
 @compiles(sqlalchemy.types.Text, "databricks")
@@ -321,3 +327,73 @@ class TINYINT(sqlalchemy.types.TypeDecorator):
@compiles(TINYINT, "databricks")
def compile_tinyint(type_, compiler, **kw):
return "TINYINT"


class DatabricksArray(UserDefinedType):
"""
A custom array type that can wrap any other SQLAlchemy type.

Examples:
DatabricksArray(String) -> ARRAY<STRING>
DatabricksArray(Integer) -> ARRAY<INT>
DatabricksArray(CustomType) -> ARRAY<CUSTOM_TYPE>
"""

def __init__(self, item_type):
self.item_type = item_type() if isinstance(item_type, type) else item_type

def bind_processor(self, dialect):
item_processor = self.item_type.bind_processor(dialect)
if item_processor is None:
item_processor = identity_processor

def process(value):
return [item_processor(val) for val in value]

return process


@compiles(DatabricksArray, "databricks")
def compile_databricks_array(type_, compiler, **kw):
inner = compiler.process(type_.item_type, **kw)

return f"ARRAY<{inner}>"


class DatabricksMap(UserDefinedType):
"""
A custom map type that can wrap any other SQLAlchemy types for both key and value.

Examples:
DatabricksMap(String, String) -> MAP<STRING,STRING>
DatabricksMap(Integer, String) -> MAP<INT,STRING>
DatabricksMap(String, DatabricksArray(Integer)) -> MAP<STRING,ARRAY<INT>>
"""

def __init__(self, key_type, value_type):
self.key_type = key_type() if isinstance(key_type, type) else key_type
self.value_type = value_type() if isinstance(value_type, type) else value_type

def bind_processor(self, dialect):
key_processor = self.key_type.bind_processor(dialect)
value_processor = self.value_type.bind_processor(dialect)

if key_processor is None:
key_processor = identity_processor
if value_processor is None:
value_processor = identity_processor

def process(value):
return {
key_processor(key): value_processor(value)
for key, value in value.items()
}

return process


@compiles(DatabricksMap, "databricks")
def compile_databricks_map(type_, compiler, **kw):
key_type = compiler.process(type_.key_type, **kw)
value_type = compiler.process(type_.value_type, **kw)
return f"MAP<{key_type},{value_type}>"
Empty file.
211 changes: 211 additions & 0 deletions tests/test_local/e2e/test_complex_types.py
@@ -0,0 +1,211 @@
from .test_setup import TestSetup
from sqlalchemy import (
Column,
BigInteger,
String,
Integer,
Numeric,
Boolean,
Date,
TIMESTAMP,
DateTime,
)
from collections.abc import Sequence
from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap
from sqlalchemy.orm import DeclarativeBase, Session
from sqlalchemy import select
from datetime import date, datetime, time, timedelta, timezone
import pandas as pd
import numpy as np
import decimal


class TestComplexTypes(TestSetup):
def _parse_to_common_type(self, value):
"""
Function to convert the :value passed into a common python datatype for comparison

Convertion fyi
MAP Datatype on server is returned as a list of tuples
Ex:
{"a":1,"b":2} -> [("a",1),("b",2)]

ARRAY Datatype on server is returned as a numpy array
Ex:
["a","b","c"] -> np.array(["a","b","c"],dtype=object)

Primitive datatype on server is returned as a numpy primitive
Ex:
1 -> np.int64(1)
2 -> np.int32(2)
"""
if value is None:
return None
elif isinstance(value, (Sequence, np.ndarray)) and not isinstance(
value, (str, bytes)
):
return tuple(value)
elif isinstance(value, dict):
return tuple(value.items())
elif isinstance(value, np.generic):
return value.item()
elif isinstance(value, decimal.Decimal):
return float(value)
else:
return value

def _recursive_compare(self, actual, expected):
"""
Function to compare the :actual and :expected values, recursively checks and ensures that all the data matches till the leaf level

Note: Complex datatype like MAP is not returned as a dictionary but as a list of tuples
"""
actual_parsed = self._parse_to_common_type(actual)
expected_parsed = self._parse_to_common_type(expected)

# Check if types are the same
if type(actual_parsed) != type(expected_parsed):
return False

# Handle lists or tuples
if isinstance(actual_parsed, (list, tuple)):
if len(actual_parsed) != len(expected_parsed):
return False
return all(
self._recursive_compare(o1, o2)
for o1, o2 in zip(actual_parsed, expected_parsed)
)

return actual_parsed == expected_parsed

    def sample_array_table(self) -> tuple[type[DeclarativeBase], dict]:
class Base(DeclarativeBase):
pass

class ArrayTable(Base):
__tablename__ = "sqlalchemy_array_table"

int_col = Column(Integer, primary_key=True)
array_int_col = Column(DatabricksArray(Integer))
array_bigint_col = Column(DatabricksArray(BigInteger))
array_numeric_col = Column(DatabricksArray(Numeric(10, 2)))
array_string_col = Column(DatabricksArray(String))
array_boolean_col = Column(DatabricksArray(Boolean))
array_date_col = Column(DatabricksArray(Date))
array_datetime_col = Column(DatabricksArray(TIMESTAMP))
array_datetime_col_ntz = Column(DatabricksArray(DateTime))
array_tinyint_col = Column(DatabricksArray(TINYINT))

sample_data = {
"int_col": 1,
"array_int_col": [1, 2],
"array_bigint_col": [1234567890123456789, 2345678901234567890],
"array_numeric_col": [1.1, 2.2],
"array_string_col": ["a", "b"],
"array_boolean_col": [True, False],
"array_date_col": [date(2020, 12, 25), date(2021, 1, 2)],
"array_datetime_col": [
datetime(1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8))),
datetime(1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8))),
],
"array_datetime_col_ntz": [
datetime(1990, 12, 4, 6, 33, 41),
datetime(1990, 12, 4, 6, 33, 41),
],
"array_tinyint_col": [-100, 100],
}

return ArrayTable, sample_data

    def sample_map_table(self) -> tuple[type[DeclarativeBase], dict]:
class Base(DeclarativeBase):
pass

class MapTable(Base):
__tablename__ = "sqlalchemy_map_table"

int_col = Column(Integer, primary_key=True)
map_int_col = Column(DatabricksMap(Integer, Integer))
map_bigint_col = Column(DatabricksMap(Integer, BigInteger))
map_numeric_col = Column(DatabricksMap(Integer, Numeric(10, 2)))
map_string_col = Column(DatabricksMap(Integer, String))
map_boolean_col = Column(DatabricksMap(Integer, Boolean))
map_date_col = Column(DatabricksMap(Integer, Date))
map_datetime_col = Column(DatabricksMap(Integer, TIMESTAMP))
map_datetime_col_ntz = Column(DatabricksMap(Integer, DateTime))
map_tinyint_col = Column(DatabricksMap(Integer, TINYINT))

sample_data = {
"int_col": 1,
"map_int_col": {1: 1},
"map_bigint_col": {1: 1234567890123456789},
"map_numeric_col": {1: 1.1},
"map_string_col": {1: "a"},
"map_boolean_col": {1: True},
"map_date_col": {1: date(2020, 12, 25)},
"map_datetime_col": {
1: datetime(1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8)))
},
"map_datetime_col_ntz": {1: datetime(1990, 12, 4, 6, 33, 41)},
"map_tinyint_col": {1: -100},
}

return MapTable, sample_data

def test_insert_array_table_sqlalchemy(self):
table, sample_data = self.sample_array_table()

with self.table_context(table) as engine:
sa_obj = table(**sample_data)
session = Session(engine)
session.add(sa_obj)
session.commit()

stmt = select(table).where(table.int_col == 1)

result = session.scalar(stmt)

compare = {key: getattr(result, key) for key in sample_data.keys()}
assert self._recursive_compare(compare, sample_data)

def test_insert_map_table_sqlalchemy(self):
table, sample_data = self.sample_map_table()

with self.table_context(table) as engine:
sa_obj = table(**sample_data)
session = Session(engine)
session.add(sa_obj)
session.commit()

stmt = select(table).where(table.int_col == 1)

result = session.scalar(stmt)

compare = {key: getattr(result, key) for key in sample_data.keys()}
assert self._recursive_compare(compare, sample_data)

def test_array_table_creation_pandas(self):
table, sample_data = self.sample_array_table()

with self.table_context(table) as engine:
# Insert the data into the table
df = pd.DataFrame([sample_data])
df.to_sql(table.__tablename__, engine, if_exists="append", index=False)

# Read the data from the table
stmt = select(table)
df_result = pd.read_sql(stmt, engine)
assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)

def test_map_table_creation_pandas(self):
table, sample_data = self.sample_map_table()

with self.table_context(table) as engine:
# Insert the data into the table
df = pd.DataFrame([sample_data])
df.to_sql(table.__tablename__, engine, if_exists="append", index=False)

# Read the data from the table
stmt = select(table)
df_result = pd.read_sql(stmt, engine)
assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data)
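
The normalization above is what lets a dict-based expectation compare equal to the list-of-tuples shape the server hands back for MAP columns. A standalone sketch of the idea, using hypothetical values and no server round-trip:

import numpy as np

def parse_to_common_type(value):
    # Normalize containers to tuples so shape differences disappear.
    if isinstance(value, dict):
        return tuple(value.items())  # {"a": 1, "b": 2} -> (("a", 1), ("b", 2))
    if isinstance(value, (list, np.ndarray)):
        return tuple(value)          # np.array([1, 2]) -> (1, 2)
    if isinstance(value, np.generic):
        return value.item()          # np.int64(1) -> 1
    return value

inserted = {"a": 1, "b": 2}      # what the test wrote
returned = [("a", 1), ("b", 2)]  # the shape a MAP column comes back in

assert parse_to_common_type(inserted) == parse_to_common_type(returned)
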
31 changes: 31 additions & 0 deletions tests/test_local/e2e/test_setup.py
@@ -0,0 +1,31 @@
import pytest
from sqlalchemy import create_engine, Engine
from contextlib import contextmanager
from sqlalchemy.orm import DeclarativeBase, Session


class TestSetup:
@pytest.fixture(autouse=True)
def get_details(self, connection_details):
self.arguments = connection_details.copy()

def db_engine(self) -> Engine:
HOST = self.arguments["host"]
HTTP_PATH = self.arguments["http_path"]
ACCESS_TOKEN = self.arguments["access_token"]
CATALOG = self.arguments["catalog"]
SCHEMA = self.arguments["schema"]

connect_args = {"_user_agent_entry": "SQLAlchemy e2e Tests"}

conn_string = f"databricks://token:{ACCESS_TOKEN}@{HOST}?http_path={HTTP_PATH}&catalog={CATALOG}&schema={SCHEMA}"
return create_engine(conn_string, connect_args=connect_args)

@contextmanager
def table_context(self, table: DeclarativeBase):
engine = self.db_engine()
table.metadata.create_all(engine)
try:
yield engine
finally:
table.metadata.drop_all(engine)
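
The engine URL above is assembled by hand. An equivalent sketch using SQLAlchemy's URL.create, with placeholder credentials, sidesteps manual query-string formatting and escaping:

from sqlalchemy import create_engine
from sqlalchemy.engine import URL

url = URL.create(
    "databricks",
    username="token",
    password="dapi-...",                  # placeholder access token
    host="example.cloud.databricks.com",  # placeholder workspace host
    query={
        "http_path": "/sql/1.0/warehouses/abc123",  # placeholder
        "catalog": "main",
        "schema": "default",
    },
)
engine = create_engine(url, connect_args={"_user_agent_entry": "SQLAlchemy e2e Tests"})
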
20 changes: 19 additions & 1 deletion tests/test_local/test_ddl.py
@@ -1,12 +1,13 @@
 import pytest
-from sqlalchemy import Column, MetaData, String, Table, create_engine
+from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine
 from sqlalchemy.schema import (
     CreateTable,
     DropColumnComment,
     DropTableComment,
     SetColumnComment,
     SetTableComment,
 )
+from databricks.sqlalchemy import DatabricksArray, DatabricksMap
 
 
 class DDLTestBase:
@@ -94,3 +95,20 @@ def test_alter_table_drop_comment(self, table_with_comment):
stmt = DropTableComment(table_with_comment)
output = self.compile(stmt)
assert output == "COMMENT ON TABLE martin IS NULL"


class TestTableComplexTypeDDL(DDLTestBase):
@pytest.fixture(scope="class")
def metadata(self) -> MetaData:
metadata = MetaData()
col1 = Column("array_array_string", DatabricksArray(DatabricksArray(String)))
col2 = Column("map_string_string", DatabricksMap(String, String))
table = Table("complex_type", metadata, col1, col2)
return metadata

def test_create_table_with_complex_type(self, metadata):
stmt = CreateTable(metadata.tables["complex_type"])
output = self.compile(stmt)

assert "array_array_string ARRAY<ARRAY<STRING>>" in output
assert "map_string_string MAP<STRING,STRING>" in output