Skip to content

Commit f3ded8b

Browse files
committed
migrated work for fabric connection
1 parent 7be1b00 commit f3ded8b

File tree

6 files changed

+71
-583
lines changed

6 files changed

+71
-583
lines changed
Lines changed: 41 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -1,175 +1,56 @@
1-
from typing import List, Optional
21

3-
import agate
4-
from dbt.adapters.base.relation import BaseRelation
5-
from dbt.adapters.cache import _make_ref_key_msg
6-
from dbt.adapters.sql import SQLAdapter
7-
from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME
8-
from dbt.events.functions import fire_event
9-
from dbt.events.types import SchemaCreation
102

113
from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn
124
from dbt.adapters.sqlserver.sql_server_configs import SQLServerConfigs
135
from dbt.adapters.sqlserver.sql_server_connection_manager import SQLServerConnectionManager
6+
# from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support
147

8+
# https://github.com/microsoft/dbt-fabric/blob/main/dbt/adapters/fabric/fabric_adapter.py
9+
from dbt.adapters.fabric import FabricAdapter
1510

16-
class SQLServerAdapter(SQLAdapter):
11+
12+
class SQLServerAdapter(FabricAdapter):
1713
ConnectionManager = SQLServerConnectionManager
1814
Column = SQLServerColumn
1915
AdapterSpecificConfigs = SQLServerConfigs
2016

21-
def create_schema(self, relation: BaseRelation) -> None:
22-
relation = relation.without_identifier()
23-
fire_event(SchemaCreation(relation=_make_ref_key_msg(relation)))
24-
macro_name = CREATE_SCHEMA_MACRO_NAME
25-
kwargs = {
26-
"relation": relation,
27-
}
28-
29-
if self.config.credentials.schema_authorization:
30-
kwargs["schema_authorization"] = self.config.credentials.schema_authorization
31-
macro_name = "sqlserver__create_schema_with_authorization"
32-
33-
self.execute_macro(macro_name, kwargs=kwargs)
34-
self.commit_if_has_connection()
17+
# _capabilities: CapabilityDict = CapabilityDict(
18+
# {
19+
# Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full),
20+
# Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full),
21+
# }
22+
# )
23+
24+
# region - these are implemented in fabric but not in sqlserver
25+
# _capabilities: CapabilityDict = CapabilityDict(
26+
# {
27+
# Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full),
28+
# Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full),
29+
# }
30+
# )
31+
# CONSTRAINT_SUPPORT = {
32+
# ConstraintType.check: ConstraintSupport.NOT_SUPPORTED,
33+
# ConstraintType.not_null: ConstraintSupport.ENFORCED,
34+
# ConstraintType.unique: ConstraintSupport.ENFORCED,
35+
# ConstraintType.primary_key: ConstraintSupport.ENFORCED,
36+
# ConstraintType.foreign_key: ConstraintSupport.ENFORCED,
37+
# }
38+
39+
# @available.parse(lambda *a, **k: [])
40+
# def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
41+
# """Get a list of the Columns with names and data types from the given sql."""
42+
# _, cursor = self.connections.add_select_query(sql)
43+
44+
# columns = [
45+
# self.Column.create(
46+
# column_name, self.connections.data_type_code_to_name(column_type_code)
47+
# )
48+
# # https://peps.python.org/pep-0249/#description
49+
# for column_name, column_type_code, *_ in cursor.description
50+
# ]
51+
# return columns
52+
# endregion
3553

3654
@classmethod
3755
def date_function(cls):
3856
return "getdate()"
39-
40-
@classmethod
41-
def convert_text_type(cls, agate_table, col_idx):
42-
column = agate_table.columns[col_idx]
43-
# see https://github.com/fishtown-analytics/dbt/pull/2255
44-
lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()]
45-
max_len = max(lens) if lens else 64
46-
length = max_len if max_len > 16 else 16
47-
return "varchar({})".format(length)
48-
49-
@classmethod
50-
def convert_datetime_type(cls, agate_table, col_idx):
51-
return "datetime"
52-
53-
@classmethod
54-
def convert_boolean_type(cls, agate_table, col_idx):
55-
return "bit"
56-
57-
@classmethod
58-
def convert_number_type(cls, agate_table, col_idx):
59-
decimals = agate_table.aggregate(agate.MaxPrecision(col_idx))
60-
return "float" if decimals else "int"
61-
62-
@classmethod
63-
def convert_time_type(cls, agate_table, col_idx):
64-
return "datetime"
65-
66-
# Methods used in adapter tests
67-
def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str:
68-
# note: 'interval' is not supported for T-SQL
69-
# for backwards compatibility, we're compelled to set some sort of
70-
# default. A lot of searching has led me to believe that the
71-
# '+ interval' syntax used in postgres/redshift is relatively common
72-
# and might even be the SQL standard's intention.
73-
return f"DATEADD({interval},{number},{add_to})"
74-
75-
def string_add_sql(
76-
self,
77-
add_to: str,
78-
value: str,
79-
location="append",
80-
) -> str:
81-
"""
82-
`+` is T-SQL's string concatenation operator
83-
"""
84-
if location == "append":
85-
return f"{add_to} + '{value}'"
86-
elif location == "prepend":
87-
return f"'{value}' + {add_to}"
88-
else:
89-
raise ValueError(f'Got an unexpected location value of "{location}"')
90-
91-
def get_rows_different_sql(
92-
self,
93-
relation_a: BaseRelation,
94-
relation_b: BaseRelation,
95-
column_names: Optional[List[str]] = None,
96-
except_operator: str = "EXCEPT",
97-
) -> str:
98-
"""
99-
note: USING is not supported on Synapse so COLUMNS_EQUAL_SQL is adjusted
100-
Generate SQL for a query that returns a single row with a two
101-
columns: the number of rows that are different between the two
102-
relations and the number of mismatched rows.
103-
"""
104-
# This method only really exists for test reasons.
105-
names: List[str]
106-
if column_names is None:
107-
columns = self.get_columns_in_relation(relation_a)
108-
names = sorted((self.quote(c.name) for c in columns))
109-
else:
110-
names = sorted((self.quote(n) for n in column_names))
111-
columns_csv = ", ".join(names)
112-
113-
sql = COLUMNS_EQUAL_SQL.format(
114-
columns=columns_csv,
115-
relation_a=str(relation_a),
116-
relation_b=str(relation_b),
117-
except_op=except_operator,
118-
)
119-
120-
return sql
121-
122-
def valid_incremental_strategies(self):
123-
"""The set of standard builtin strategies which this adapter supports out-of-the-box.
124-
Not used to validate custom strategies defined by end users.
125-
"""
126-
return ["append", "delete+insert", "merge", "insert_overwrite"]
127-
128-
# This is for use in the test suite
129-
def run_sql_for_tests(self, sql, fetch, conn):
130-
cursor = conn.handle.cursor()
131-
try:
132-
cursor.execute(sql)
133-
if not fetch:
134-
conn.handle.commit()
135-
if fetch == "one":
136-
return cursor.fetchone()
137-
elif fetch == "all":
138-
return cursor.fetchall()
139-
else:
140-
return
141-
except BaseException:
142-
if conn.handle and not getattr(conn.handle, "closed", True):
143-
conn.handle.rollback()
144-
raise
145-
finally:
146-
conn.transaction_open = False
147-
148-
149-
COLUMNS_EQUAL_SQL = """
150-
with diff_count as (
151-
SELECT
152-
1 as id,
153-
COUNT(*) as num_missing FROM (
154-
(SELECT {columns} FROM {relation_a} {except_op}
155-
SELECT {columns} FROM {relation_b})
156-
UNION ALL
157-
(SELECT {columns} FROM {relation_b} {except_op}
158-
SELECT {columns} FROM {relation_a})
159-
) as a
160-
), table_a as (
161-
SELECT COUNT(*) as num_rows FROM {relation_a}
162-
), table_b as (
163-
SELECT COUNT(*) as num_rows FROM {relation_b}
164-
), row_count_diff as (
165-
select
166-
1 as id,
167-
table_a.num_rows - table_b.num_rows as difference
168-
from table_a, table_b
169-
)
170-
select
171-
row_count_diff.difference as row_count_difference,
172-
diff_count.num_missing as num_mismatched
173-
from row_count_diff
174-
join diff_count on row_count_diff.id = diff_count.id
175-
""".strip()
Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,5 @@
1-
from typing import Any, ClassVar, Dict
21

3-
from dbt.adapters.base import Column
2+
from dbt.adapters.fabric import FabricColumn
43

5-
6-
class SQLServerColumn(Column):
7-
TYPE_LABELS: ClassVar[Dict[str, str]] = {
8-
"STRING": "VARCHAR(MAX)",
9-
"TIMESTAMP": "DATETIMEOFFSET",
10-
"FLOAT": "FLOAT",
11-
"INTEGER": "INT",
12-
"BOOLEAN": "BIT",
13-
}
14-
15-
@classmethod
16-
def string_type(cls, size: int) -> str:
17-
return f"varchar({size if size > 0 else 'MAX'})"
18-
19-
def literal(self, value: Any) -> str:
20-
return "cast('{}' as {})".format(value, self.data_type)
4+
class SQLServerColumn(FabricColumn):
5+
...
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from dataclasses import dataclass
2-
from typing import Optional
3-
4-
from dbt.adapters.protocol import AdapterConfig
52

3+
from dbt.adapters.fabric import FabricConfigs
64

75
@dataclass
8-
class SQLServerConfigs(AdapterConfig):
9-
auto_provision_aad_principals: Optional[bool] = False
6+
class SQLServerConfigs(FabricConfigs):
7+
...

0 commit comments

Comments
 (0)