diff --git a/.github/workflows/integration-tests-sqlserver.yml b/.github/workflows/integration-tests-sqlserver.yml index 5c7694a6..d74274b4 100644 --- a/.github/workflows/integration-tests-sqlserver.yml +++ b/.github/workflows/integration-tests-sqlserver.yml @@ -18,7 +18,7 @@ jobs: name: Regular strategy: matrix: - python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] msodbc_version: ["17", "18"] sqlserver_version: ["2017", "2019", "2022"] collation: ["SQL_Latin1_General_CP1_CS_AS", "SQL_Latin1_General_CP1_CI_AS"] diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml index 63b1f5c4..9c96f15a 100644 --- a/.github/workflows/publish-docker.yml +++ b/.github/workflows/publish-docker.yml @@ -12,7 +12,7 @@ jobs: publish-docker-client: strategy: matrix: - python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] docker_target: ["msodbc17", "msodbc18"] runs-on: ubuntu-latest permissions: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 5acae556..a2f09804 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -18,7 +18,7 @@ jobs: name: Unit tests strategy: matrix: - python_version: ["3.9", "3.10", "3.11", "3.12"] + python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] runs-on: ubuntu-latest permissions: contents: read diff --git a/Makefile b/Makefile index 91863e5a..cac2581f 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ .DEFAULT_GOAL:=help +THREADS ?= auto .PHONY: dev dev: ## Installs adapter in develop mode along with development dependencies @@ -44,7 +45,7 @@ unit: ## Runs unit tests. .PHONY: functional functional: ## Runs functional tests. @\ - pytest -n auto -ra -v tests/functional + pytest -n $(THREADS) -ra -v tests/functional .PHONY: test test: ## Runs unit tests and code checks against staged changes. 
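Usage note for the new THREADS variable (illustrative, not part of the patch): it defaults to auto, so the functional target behaves exactly as before, but the pytest-xdist worker count can now be pinned per invocation:

    make functional              # unchanged: pytest -n auto -ra -v tests/functional
    make functional THREADS=4    # pinned:    pytest -n 4 -ra -v tests/functional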
diff --git a/dbt/adapters/sqlserver/__init__.py b/dbt/adapters/sqlserver/__init__.py index 879ea74c..f1466f59 100644 --- a/dbt/adapters/sqlserver/__init__.py +++ b/dbt/adapters/sqlserver/__init__.py @@ -11,7 +11,7 @@ adapter=SQLServerAdapter, credentials=SQLServerCredentials, include_path=sqlserver.PACKAGE_PATH, - dependencies=["fabric"], + dependencies=[], ) __all__ = [ diff --git a/dbt/adapters/sqlserver/sqlserver_adapter.py b/dbt/adapters/sqlserver/sqlserver_adapter.py index 6f05c501..0f6ad130 100644 --- a/dbt/adapters/sqlserver/sqlserver_adapter.py +++ b/dbt/adapters/sqlserver/sqlserver_adapter.py @@ -1,24 +1,45 @@ -from typing import Optional +from typing import List, Optional -import dbt.exceptions +import agate +import dbt_common.exceptions +from dbt.adapters.base.column import Column as BaseColumn from dbt.adapters.base.impl import ConstraintSupport -from dbt.adapters.fabric import FabricAdapter -from dbt.contracts.graph.nodes import ConstraintType +from dbt.adapters.base.meta import available +from dbt.adapters.base.relation import BaseRelation +from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support +from dbt.adapters.events.types import SchemaCreation +from dbt.adapters.reference_keys import _make_ref_key_dict +from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME, SQLAdapter +from dbt_common.behavior_flags import BehaviorFlag +from dbt_common.contracts.constraints import ( + ColumnLevelConstraint, + ConstraintType, + ModelLevelConstraint, +) +from dbt_common.events.functions import fire_event from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn +from dbt.adapters.sqlserver.sqlserver_configs import SQLServerConfigs from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation -class SQLServerAdapter(FabricAdapter): +class SQLServerAdapter(SQLAdapter): """ Controls actual implmentation of adapter, and ability to override certain methods. """ ConnectionManager = SQLServerConnectionManager Column = SQLServerColumn + AdapterSpecificConfigs = SQLServerConfigs Relation = SQLServerRelation + _capabilities: CapabilityDict = CapabilityDict( + { + Capability.SchemaMetadataByRelations: CapabilitySupport(support=Support.Full), + Capability.TableLastModifiedMetadata: CapabilitySupport(support=Support.Full), + } + ) CONSTRAINT_SUPPORT = { ConstraintType.check: ConstraintSupport.ENFORCED, ConstraintType.not_null: ConstraintSupport.ENFORCED, @@ -27,13 +48,184 @@ class SQLServerAdapter(FabricAdapter): ConstraintType.foreign_key: ConstraintSupport.ENFORCED, } + @property + def _behavior_flags(self) -> List[BehaviorFlag]: + return [ + { + "name": "empty", + "default": False, + "description": ( + "When enabled, table and view materializations will be created as empty " + "structures (no data)." 
+ ), + }, + ] + + @available.parse(lambda *a, **k: []) + def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]: + """Get a list of the Columns with names and data types from the given sql.""" + _, cursor = self.connections.add_select_query(sql) + + columns = [ + self.Column.create( + column_name, self.connections.data_type_code_to_name(column_type_code) + ) + # https://peps.python.org/pep-0249/#description + for column_name, column_type_code, *_ in cursor.description + ] + return columns + + @classmethod + def convert_boolean_type(cls, agate_table, col_idx): + return "bit" + + @classmethod + def convert_datetime_type(cls, agate_table, col_idx): + return "datetime2(6)" + @classmethod - def render_model_constraint(cls, constraint) -> Optional[str]: + def convert_number_type(cls, agate_table, col_idx): + decimals = agate_table.aggregate(agate.MaxPrecision(col_idx)) + return "float" if decimals else "int" + + def create_schema(self, relation: BaseRelation) -> None: + relation = relation.without_identifier() + fire_event(SchemaCreation(relation=_make_ref_key_dict(relation))) + macro_name = CREATE_SCHEMA_MACRO_NAME + kwargs = { + "relation": relation, + } + + if self.config.credentials.schema_authorization: + kwargs["schema_authorization"] = self.config.credentials.schema_authorization + macro_name = "sqlserver__create_schema_with_authorization" + + self.execute_macro(macro_name, kwargs=kwargs) + self.commit_if_has_connection() + + @classmethod + def convert_text_type(cls, agate_table, col_idx): + column = agate_table.columns[col_idx] + # see https://github.com/fishtown-analytics/dbt/pull/2255 + lens = [len(d.encode("utf-8")) for d in column.values_without_nulls()] + max_len = max(lens) if lens else 64 + length = max_len if max_len > 16 else 16 + return "varchar({})".format(length) + + @classmethod + def convert_time_type(cls, agate_table, col_idx): + return "time(6)" + + @classmethod + def date_function(cls): + return "getdate()" + + # Methods used in adapter tests + def timestamp_add_sql(self, add_to: str, number: int = 1, interval: str = "hour") -> str: + # note: 'interval' is not supported for T-SQL + # for backwards compatibility, we're compelled to set some sort of + # default. A lot of searching has led me to believe that the + # '+ interval' syntax used in postgres/redshift is relatively common + # and might even be the SQL standard's intention. + return f"DATEADD({interval},{number},{add_to})" + + def string_add_sql( + self, + add_to: str, + value: str, + location="append", + ) -> str: + """ + `+` is T-SQL's string concatenation operator + """ + if location == "append": + return f"{add_to} + '{value}'" + elif location == "prepend": + return f"'{value}' + {add_to}" + else: + raise ValueError(f'Got an unexpected location value of "{location}"') + + def get_rows_different_sql( + self, + relation_a: BaseRelation, + relation_b: BaseRelation, + column_names: Optional[List[str]] = None, + except_operator: str = "EXCEPT", + ) -> str: + """ + note: using is not supported on Synapse so COLUMNS_EQUAL_SQL is adjusted + Generate SQL for a query that returns a single row with two + columns: the number of rows that are different between the two + relations and the number of mismatched rows. + """ + # This method only really exists for test reasons.
+ names: List[str] + if column_names is None: + columns = self.get_columns_in_relation(relation_a) + names = sorted((self.quote(c.name) for c in columns)) + else: + names = sorted((self.quote(n) for n in column_names)) + columns_csv = ", ".join(names) + + if columns_csv == "": + columns_csv = "*" + + sql = COLUMNS_EQUAL_SQL.format( + columns=columns_csv, + relation_a=str(relation_a), + relation_b=str(relation_b), + except_op=except_operator, + ) + + return sql + + def valid_incremental_strategies(self): + """The set of standard builtin strategies which this adapter supports out-of-the-box. + Not used to validate custom strategies defined by end users. + """ + return ["append", "delete+insert", "merge", "microbatch"] + + # This is for use in the test suite + def run_sql_for_tests(self, sql, fetch, conn): + cursor = conn.handle.cursor() + try: + cursor.execute(sql) + if not fetch: + conn.handle.commit() + if fetch == "one": + return cursor.fetchone() + elif fetch == "all": + return cursor.fetchall() + else: + return + except BaseException: + if conn.handle and not getattr(conn.handle, "closed", True): + conn.handle.rollback() + raise + finally: + conn.transaction_open = False + + @available + @classmethod + def render_column_constraint(cls, constraint: ColumnLevelConstraint) -> Optional[str]: + rendered_column_constraint = None + if constraint.type == ConstraintType.not_null: + rendered_column_constraint = "not null " + else: + rendered_column_constraint = "" + + if rendered_column_constraint: + rendered_column_constraint = rendered_column_constraint.strip() + + return rendered_column_constraint + + @classmethod + def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[str]: constraint_prefix = "add constraint " column_list = ", ".join(constraint.columns) if constraint.name is None: - raise dbt.exceptions.DbtDatabaseError( + raise dbt_common.exceptions.DbtDatabaseError( "Constraint name cannot be empty. Provide constraint name - column " + column_list + " and run the project again." @@ -56,12 +248,31 @@ def render_model_constraint(cls, constraint) -> Optional[str]: else: return None - @classmethod - def date_function(cls): - return "getdate()" - def valid_incremental_strategies(self): - """The set of standard builtin strategies which this adapter supports out-of-the-box. - Not used to validate custom strategies defined by end users. 
- """ - return ["append", "delete+insert", "merge", "microbatch"] +COLUMNS_EQUAL_SQL = """ +with diff_count as ( + SELECT + 1 as id, + COUNT(*) as num_missing FROM ( + (SELECT {columns} FROM {relation_a} {except_op} + SELECT {columns} FROM {relation_b}) + UNION ALL + (SELECT {columns} FROM {relation_b} {except_op} + SELECT {columns} FROM {relation_a}) + ) as a +), table_a as ( + SELECT COUNT(*) as num_rows FROM {relation_a} +), table_b as ( + SELECT COUNT(*) as num_rows FROM {relation_b} +), row_count_diff as ( + select + 1 as id, + table_a.num_rows - table_b.num_rows as difference + from table_a, table_b +) +select + row_count_diff.difference as row_count_difference, + diff_count.num_missing as num_mismatched +from row_count_diff +join diff_count on row_count_diff.id = diff_count.id +""".strip() diff --git a/dbt/adapters/sqlserver/sqlserver_column.py b/dbt/adapters/sqlserver/sqlserver_column.py index 68ef98e3..9bdf3fcc 100644 --- a/dbt/adapters/sqlserver/sqlserver_column.py +++ b/dbt/adapters/sqlserver/sqlserver_column.py @@ -1,7 +1,67 @@ -from dbt.adapters.fabric import FabricColumn +from typing import Any, ClassVar, Dict +from dbt.adapters.base.column import Column +from dbt_common.exceptions import DbtRuntimeError + + +class SQLServerColumn(Column): + TYPE_LABELS: ClassVar[Dict[str, str]] = { + "STRING": "VARCHAR(8000)", + "VARCHAR": "VARCHAR(8000)", + "CHAR": "CHAR(1)", + "NCHAR": "CHAR(1)", + "NVARCHAR": "VARCHAR(8000)", + "TIMESTAMP": "DATETIME2(6)", + "DATETIME2": "DATETIME2(6)", + "DATETIME2(6)": "DATETIME2(6)", + "DATE": "DATE", + "TIME": "TIME(6)", + "FLOAT": "FLOAT", + "REAL": "REAL", + "INT": "INT", + "INTEGER": "INT", + "BIGINT": "BIGINT", + "SMALLINT": "SMALLINT", + "TINYINT": "SMALLINT", + "BIT": "BIT", + "BOOLEAN": "BIT", + "DECIMAL": "DECIMAL", + "NUMERIC": "NUMERIC", + "MONEY": "DECIMAL", + "SMALLMONEY": "DECIMAL", + "UNIQUEIDENTIFIER": "UNIQUEIDENTIFIER", + "VARBINARY": "VARBINARY(MAX)", + "BINARY": "BINARY(1)", + } + + @classmethod + def string_type(cls, size: int) -> str: + return f"varchar({size if size > 0 else '8000'})" + + def literal(self, value: Any) -> str: + return "cast('{}' as {})".format(value, self.data_type) + + @property + def data_type(self) -> str: + # Always enforce datetime2 precision + if self.dtype.lower() == "datetime2": + return "datetime2(6)" + if self.is_string(): + return self.string_type(self.string_size()) + elif self.is_numeric(): + return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) + else: + return self.dtype + + def is_string(self) -> bool: + return self.dtype.lower() in ["varchar", "char"] + + def is_number(self): + return any([self.is_integer(), self.is_numeric(), self.is_float()]) + + def is_float(self): + return self.dtype.lower() in ["float", "real"] -class SQLServerColumn(FabricColumn): def is_integer(self) -> bool: return self.dtype.lower() in [ # real types @@ -20,3 +80,19 @@ def is_integer(self) -> bool: "serial8", "int", ] + + def is_numeric(self) -> bool: + return self.dtype.lower() in ["numeric", "decimal", "money", "smallmoney"] + + def string_size(self) -> int: + if not self.is_string(): + raise DbtRuntimeError("Called string_size() on non-string field!") + if self.char_size is None: + return 8000 + else: + return int(self.char_size) + + def can_expand_to(self, other_column: "SQLServerColumn") -> bool: + if not self.is_string() or not other_column.is_string(): + return False + return other_column.string_size() > self.string_size() diff --git a/dbt/adapters/sqlserver/sqlserver_configs.py 
b/dbt/adapters/sqlserver/sqlserver_configs.py index 35ce4262..bf6d2d1e 100644 --- a/dbt/adapters/sqlserver/sqlserver_configs.py +++ b/dbt/adapters/sqlserver/sqlserver_configs.py @@ -1,8 +1,9 @@ from dataclasses import dataclass +from typing import Optional -from dbt.adapters.fabric import FabricConfigs +from dbt.adapters.protocol import AdapterConfig @dataclass -class SQLServerConfigs(FabricConfigs): - pass +class SQLServerConfigs(AdapterConfig): + auto_provision_aad_principals: Optional[bool] = False diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index a91baeb1..a4c5c347 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -1,27 +1,168 @@ -import dbt_common.exceptions # noqa +import datetime as dt +import struct +import time +from contextlib import contextmanager +from itertools import chain, repeat +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union + +import agate +import dbt_common.exceptions import pyodbc from azure.core.credentials import AccessToken -from azure.identity import ClientSecretCredential, ManagedIdentityCredential -from dbt.adapters.contracts.connection import Connection, ConnectionState -from dbt.adapters.events.logging import AdapterLogger -from dbt.adapters.fabric import FabricConnectionManager -from dbt.adapters.fabric.fabric_connection_manager import ( - AZURE_AUTH_FUNCTIONS as AZURE_AUTH_FUNCTIONS_FABRIC, -) -from dbt.adapters.fabric.fabric_connection_manager import ( - AZURE_CREDENTIAL_SCOPE, - bool_to_connection_string_arg, - get_pyodbc_attrs_before_accesstoken, - get_pyodbc_attrs_before_credentials, +from azure.identity import ( + AzureCliCredential, + ClientSecretCredential, + DefaultAzureCredential, + EnvironmentCredential, + ManagedIdentityCredential, ) +from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus +from dbt.adapters.sql.connections import SQLConnectionManager +from dbt_common.clients.agate_helper import empty_table +from dbt_common.events.contextvars import get_node_info +from dbt_common.events.functions import fire_event +from dbt_common.utils.casting import cast_to_str from dbt.adapters.sqlserver import __version__ from dbt.adapters.sqlserver.sqlserver_credentials import SQLServerCredentials +_TOKEN: Optional[AccessToken] = None +AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default" +AZURE_AUTH_FUNCTION_TYPE = Callable[[SQLServerCredentials, Optional[str]], AccessToken] + logger = AdapterLogger("sqlserver") +# https://github.com/mkleehammer/pyodbc/wiki/Data-Types +datatypes = { + "str": "varchar", + "uuid.UUID": "uniqueidentifier", + "uuid": "uniqueidentifier", + "float": "bigint", + "int": "int", + "bytes": "varbinary", + "bytearray": "varbinary", + "bool": "bit", + "datetime.date": "date", + "datetime.datetime": "datetime2(6)", + "datetime.time": "time", + "decimal.Decimal": "decimal", +} + + +def convert_bytes_to_mswindows_byte_string(value: bytes) -> bytes: + """ + Convert bytes to a Microsoft windows byte string. + + Parameters + ---------- + value : bytes + The bytes. + + Returns + ------- + out : bytes + The Microsoft byte string. 
+ """ + encoded_bytes = bytes(chain.from_iterable(zip(value, repeat(0)))) + return struct.pack(" bytes: + """ + Convert an access token to a Microsoft windows byte string. + + Parameters + ---------- + token : AccessToken + The token. + + Returns + ------- + out : bytes + The Microsoft byte string. + """ + value = bytes(token.token, "UTF-8") + return convert_bytes_to_mswindows_byte_string(value) + + +def get_cli_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessToken: + """ + Get an Azure access token using the CLI credentials + + First login with: + + ```bash + az login + ``` + + Parameters + ---------- + credentials: SQLServerCredentials + The credentials. + + Returns + ------- + out : AccessToken + Access token. + """ + _ = credentials + token = AzureCliCredential().get_token( + scope, timeout=getattr(credentials, "login_timeout", None) + ) + return token + + +def get_auto_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessToken: + """ + Get an Azure access token automatically through azure-identity + + Parameters + ----------- + credentials: SQLServerCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + token = DefaultAzureCredential().get_token( + scope, timeout=getattr(credentials, "login_timeout", None) + ) + return token + + +def get_environment_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessToken: + """ + Get an Azure access token by reading environment variables + + Parameters + ----------- + credentials: SQLServerCredentials + Credentials. + + Returns + ------- + out : AccessToken + The access token. + """ + token = EnvironmentCredential().get_token( + scope, timeout=getattr(credentials, "login_timeout", None) + ) + return token + -def get_msi_access_token(credentials: SQLServerCredentials) -> AccessToken: +def get_msi_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessToken: """ Get an Azure access token from the system's managed identity @@ -35,11 +176,14 @@ def get_msi_access_token(credentials: SQLServerCredentials) -> AccessToken: out : AccessToken The access token. """ - token = ManagedIdentityCredential().get_token(AZURE_CREDENTIAL_SCOPE) + _ = credentials + token = ManagedIdentityCredential().get_token(scope) return token -def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken: +def get_sp_access_token( + credentials: SQLServerCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE +) -> AccessToken: """ Get an Azure access token using the SP credentials. @@ -53,6 +197,7 @@ def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken: out : AccessToken The access token. 
""" + _ = scope token = ClientSecretCredential( str(credentials.tenant_id), str(credentials.client_id), @@ -61,16 +206,146 @@ def get_sp_access_token(credentials: SQLServerCredentials) -> AccessToken: return token -AZURE_AUTH_FUNCTIONS = { - **AZURE_AUTH_FUNCTIONS_FABRIC, +AZURE_AUTH_FUNCTIONS: Mapping[str, AZURE_AUTH_FUNCTION_TYPE] = { + "cli": get_cli_access_token, + "auto": get_auto_access_token, + "environment": get_environment_access_token, "serviceprincipal": get_sp_access_token, "msi": get_msi_access_token, } -class SQLServerConnectionManager(FabricConnectionManager): +def get_pyodbc_attrs_before_credentials(credentials: SQLServerCredentials) -> Dict: + """ + Get the pyodbc attributes for authentication. + + Parameters + ---------- + credentials : SQLServerCredentials + Credentials. + + Returns + ------- + Dict + The pyodbc attributes for authentication. + """ + global _TOKEN + sql_copt_ss_access_token = 1256 # ODBC constant for access token + MAX_REMAINING_TIME = 300 + + if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS: + if not _TOKEN or (_TOKEN.expires_on - time.time() < MAX_REMAINING_TIME): + _TOKEN = AZURE_AUTH_FUNCTIONS[credentials.authentication.lower()]( + credentials, AZURE_CREDENTIAL_SCOPE + ) + token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN) + return {sql_copt_ss_access_token: token_bytes} + + if credentials.authentication.lower() == "activedirectoryaccesstoken": + if credentials.access_token is None or credentials.access_token_expires_on is None: + raise ValueError( + ( + "Access token and access token expiry are " + "required for ActiveDirectoryAccessToken authentication." + ) + ) + _TOKEN = AccessToken( + token=credentials.access_token, + expires_on=int( + time.time() + 4500.0 + if credentials.access_token_expires_on == 0 + else credentials.access_token_expires_on + ), + ) + return {sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(_TOKEN)} + + return {} + + +def bool_to_connection_string_arg(key: str, value: bool) -> str: + """ + Convert a boolean to a connection string argument. + + Parameters + ---------- + key : str + The key to use in the connection string. + value : bool + The boolean to convert. + + Returns + ------- + out : str + The connection string argument. 
+ """ + return f'{key}={"Yes" if value else "No"}' + + +def byte_array_to_datetime(value: bytes) -> dt.datetime: + """ + Converts a DATETIMEOFFSET byte array to a timezone-aware datetime object + + Parameters + ---------- + value : buffer + A binary value conforming to SQL_SS_TIMESTAMPOFFSET_STRUCT + + Returns + ------- + out : datetime + + Source + ------ + SQL_SS_TIMESTAMPOFFSET datatype and SQL_SS_TIMESTAMPOFFSET_STRUCT layout: + https://learn.microsoft.com/sql/relational-databases/native-client-odbc-date-time/data-type-support-for-odbc-date-and-time-improvements + """ + # unpack 20 bytes of data into a tuple of 9 values + tup = struct.unpack("<6hI2h", value) + + # construct a datetime object + return dt.datetime( + year=tup[0], + month=tup[1], + day=tup[2], + hour=tup[3], + minute=tup[4], + second=tup[5], + microsecond=tup[6] // 1000, # https://bugs.python.org/issue15443 + tzinfo=dt.timezone(dt.timedelta(hours=tup[7], minutes=tup[8])), + ) + + +class SQLServerConnectionManager(SQLConnectionManager): TYPE = "sqlserver" + @contextmanager + def exception_handler(self, sql): + try: + yield + + except pyodbc.DatabaseError as e: + logger.debug("Database error: {}".format(str(e))) + + try: + # attempt to release the connection + self.release() + except pyodbc.Error: + logger.debug("Failed to release connection!") + + raise dbt_common.exceptions.DbtDatabaseError(str(e).strip()) from e + + except Exception as e: + logger.debug(f"Error running SQL: {sql}") + logger.debug("Rolling back transaction.") + self.release() + if isinstance(e, dbt_common.exceptions.DbtRuntimeError): + # during a sql query, an internal to dbt exception was raised. + # this sounds a lot like a signal handler and probably has + # useful information, so raise it without modification. + raise + + raise dbt_common.exceptions.DbtRuntimeError(e) + @classmethod def open(cls, connection: Connection) -> Connection: if connection.state == ConnectionState.OPEN: @@ -78,10 +353,6 @@ def open(cls, connection: Connection) -> Connection: return connection credentials = cls.get_credentials(connection.credentials) - if credentials.authentication != "sql": - return super().open(connection) - - # sql login authentication con_str = [f"DRIVER={{{credentials.driver}}}"] @@ -93,11 +364,38 @@ def open(cls, connection: Connection) -> Connection: con_str.append(f"SERVER={credentials.host},{credentials.port}") con_str.append(f"Database={credentials.database}") + con_str.append("Pooling=true") + + # Enabling trace flag + if credentials.trace_flag: + con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_ON") + else: + con_str.append("SQL_ATTR_TRACE=SQL_OPT_TRACE_OFF") assert credentials.authentication is not None - con_str.append(f"UID={{{credentials.UID}}}") - con_str.append(f"PWD={{{credentials.PWD}}}") + # Access token authentication does not additional connection string parameters. + # The access token is passed in the pyodbc attributes. 
+ if ( + "ActiveDirectory" in credentials.authentication + and credentials.authentication != "ActiveDirectoryAccessToken" + ): + con_str.append(f"Authentication={credentials.authentication}") + + if credentials.authentication == "ActiveDirectoryPassword": + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") + if credentials.authentication == "ActiveDirectoryServicePrincipal": + con_str.append(f"UID={{{credentials.client_id}}}") + con_str.append(f"PWD={{{credentials.client_secret}}}") + elif credentials.authentication == "ActiveDirectoryInteractive": + con_str.append(f"UID={{{credentials.UID}}}") + + elif credentials.windows_login: + con_str.append("trusted_connection=Yes") + elif credentials.authentication == "sql": + con_str.append(f"UID={{{credentials.UID}}}") + con_str.append(f"PWD={{{credentials.PWD}}}") # https://docs.microsoft.com/en-us/sql/relational-databases/native-client/features/using-encryption-without-validation?view=sql-server-ver15 assert credentials.encrypt is not None @@ -112,6 +410,19 @@ def open(cls, connection: Connection) -> Connection: application_name = f"dbt-{credentials.type}/{plugin_version}" con_str.append(f"APP={application_name}") + try: + con_str.append("ConnectRetryCount=3") + con_str.append("ConnectRetryInterval=10") + + except Exception as e: + logger.debug( + ( + "Retry count should be a integer value. " + "Skipping retries in the connection string." + ), + str(e), + ) + con_str_concat = ";".join(con_str) index = [] @@ -135,11 +446,10 @@ def open(cls, connection: Connection) -> Connection: def connect(): logger.debug(f"Using connection string: {con_str_display}") + pyodbc.pooling = True - if credentials.authentication == "ActiveDirectoryAccessToken": - attrs_before = get_pyodbc_attrs_before_accesstoken(credentials.access_token) - else: - attrs_before = get_pyodbc_attrs_before_credentials(credentials) + # pyodbc attributes includes the access token provided by the user if required. + attrs_before = get_pyodbc_attrs_before_credentials(credentials) handle = pyodbc.connect( con_str_concat, @@ -151,10 +461,175 @@ def connect(): logger.debug(f"Connected to db: {credentials.database}") return handle - return cls.retry_connection( + conn = cls.retry_connection( connection, connect=connect, logger=logger, retry_limit=credentials.retries, retryable_exceptions=retryable_exceptions, ) + + return conn + + def cancel(self, connection: Connection): + logger.debug("Cancel query") + + def add_begin_query(self): + # return self.add_query('BEGIN TRANSACTION', auto_begin=False) + pass + + def add_commit_query(self): + # return self.add_query('COMMIT TRANSACTION', auto_begin=False) + pass + + def add_query( + self, + sql: str, + auto_begin: bool = True, + bindings: Optional[Any] = None, + abridge_sql_log: bool = False, + retryable_exceptions: Tuple[Type[Exception], ...] = (), + retry_limit: int = 2, + ) -> Tuple[Connection, Any]: + """ + Retry function encapsulated here to avoid commitment to some + user-facing interface. Right now, Redshift commits to a 1 second + retry timeout so this serves as a default. + """ + + def _execute_query_with_retry( + cursor: Any, + sql: str, + bindings: Optional[Any], + retryable_exceptions: Tuple[Type[Exception], ...], + retry_limit: int, + attempt: int, + ): + """ + A success sees the try exit cleanly and avoid any recursive + retries. Failure begins a sleep and retry routine. + """ + try: + # pyodbc does not handle a None type binding! 
+ if bindings is None: + cursor.execute(sql) + else: + bindings = [ + binding if not isinstance(binding, dt.datetime) else binding.isoformat() + for binding in bindings + ] + cursor.execute(sql, bindings) + except retryable_exceptions as e: + # Cease retries and fail when limit is hit. + if attempt >= retry_limit: + raise e + + fire_event( + AdapterEventDebug( + message=( + f"Got a retryable error {type(e)}. {retry_limit-attempt} " + "retries left. Retrying in 1 second.\n" + f"Error:\n{e}" + ) + ) + ) + time.sleep(1) + + return _execute_query_with_retry( + cursor=cursor, + sql=sql, + bindings=bindings, + retryable_exceptions=retryable_exceptions, + retry_limit=retry_limit, + attempt=attempt + 1, + ) + + connection = self.get_thread_connection() + + if auto_begin and connection.transaction_open is False: + self.begin() + + fire_event( + ConnectionUsed( + conn_type=self.TYPE, + conn_name=cast_to_str(connection.name), + node_info=get_node_info(), + ) + ) + + with self.exception_handler(sql): + if abridge_sql_log: + log_sql = "{}...".format(sql[:512]) + else: + log_sql = sql + + fire_event( + SQLQuery( + conn_name=cast_to_str(connection.name), sql=log_sql, node_info=get_node_info() + ) + ) + + pre = time.time() + + cursor = connection.handle.cursor() + credentials = self.get_credentials(connection.credentials) + + _execute_query_with_retry( + cursor=cursor, + sql=sql, + bindings=bindings, + retryable_exceptions=retryable_exceptions, + retry_limit=credentials.retries if credentials.retries > 3 else retry_limit, + attempt=1, + ) + + # convert DATETIMEOFFSET binary structures to datetime ojbects + # https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 + connection.handle.add_output_converter(-155, byte_array_to_datetime) + + fire_event( + SQLQueryStatus( + status=str(self.get_response(cursor)), + elapsed=round((time.time() - pre)), + node_info=get_node_info(), + ) + ) + + return connection, cursor + + @classmethod + def get_credentials(cls, credentials: SQLServerCredentials) -> SQLServerCredentials: + return credentials + + @classmethod + def get_response(cls, cursor: Any) -> AdapterResponse: + message = "OK" + rows = cursor.rowcount + return AdapterResponse( + _message=message, + rows_affected=rows, + ) + + @classmethod + def data_type_code_to_name(cls, type_code: Union[str, str]) -> str: + data_type = str(type_code)[ + str(type_code).index("'") + 1 : str(type_code).rindex("'") # noqa: E203 + ] + return datatypes[data_type] + + def execute( + self, sql: str, auto_begin: bool = True, fetch: bool = False, limit: Optional[int] = None + ) -> Tuple[AdapterResponse, agate.Table]: + sql = self._add_query_comment(sql) + _, cursor = self.add_query(sql, auto_begin) + response = self.get_response(cursor) + if fetch: + while cursor.description is None: + if not cursor.nextset(): + break + table = self.get_result_from_cursor(cursor, limit) + else: + table = empty_table() + while cursor.nextset(): + pass + return response, table diff --git a/dbt/adapters/sqlserver/sqlserver_credentials.py b/dbt/adapters/sqlserver/sqlserver_credentials.py index bf1f5075..37bba77e 100644 --- a/dbt/adapters/sqlserver/sqlserver_credentials.py +++ b/dbt/adapters/sqlserver/sqlserver_credentials.py @@ -1,22 +1,74 @@ from dataclasses import dataclass from typing import Optional -from dbt.adapters.fabric import FabricCredentials +from dbt.adapters.contracts.connection import Credentials @dataclass -class SQLServerCredentials(FabricCredentials): - """ - Defines database specific credentials that get added to - 
profiles.yml to connect to new adapter - """ - +class SQLServerCredentials(Credentials): + driver: str + host: str + database: str + schema: str + UID: Optional[str] = None + PWD: Optional[str] = None port: Optional[int] = 1433 - authentication: Optional[str] = "sql" + windows_login: Optional[bool] = False + trace_flag: Optional[bool] = False + tenant_id: Optional[str] = None + client_id: Optional[str] = None + client_secret: Optional[str] = None + access_token: Optional[str] = None + access_token_expires_on: Optional[int] = 0 + authentication: str = "sql" + encrypt: Optional[bool] = True + trust_cert: Optional[bool] = False + retries: int = 3 + schema_authorization: Optional[str] = None + login_timeout: Optional[int] = 0 + query_timeout: Optional[int] = 0 + + _ALIASES = { + "user": "UID", + "username": "UID", + "pass": "PWD", + "password": "PWD", + "server": "host", + "trusted_connection": "windows_login", + "auth": "authentication", + "app_id": "client_id", + "app_secret": "client_secret", + "TrustServerCertificate": "trust_cert", + "schema_auth": "schema_authorization", + "SQL_ATTR_TRACE": "trace_flag", + } @property def type(self): return "sqlserver" def _connection_keys(self): - return super()._connection_keys() + ("port",) + if self.windows_login is True: + self.authentication = "Windows Login" + + if self.authentication.lower().strip() == "serviceprincipal": + self.authentication = "ActiveDirectoryServicePrincipal" + + return ( + "server", + "port", + "database", + "schema", + "UID", + "authentication", + "retries", + "login_timeout", + "query_timeout", + "trace_flag", + "encrypt", + "trust_cert", + ) + + @property + def unique_field(self): + return self.host diff --git a/dbt/include/sqlserver/macros/adapter/columns.sql b/dbt/include/sqlserver/macros/adapter/columns.sql deleted file mode 100644 index a98750e7..00000000 --- a/dbt/include/sqlserver/macros/adapter/columns.sql +++ /dev/null @@ -1,50 +0,0 @@ -{% macro sqlserver__get_empty_subquery_sql(select_sql, select_sql_header=none) %} - {% if select_sql.strip().lower().startswith('with') %} - {{ select_sql }} - {% else -%} - select * from ( - {{ select_sql }} - ) dbt_sbq_tmp - where 1 = 0 - {%- endif -%} - -{% endmacro %} - -{% macro sqlserver__get_columns_in_query(select_sql) %} - {% set query_label = apply_label() %} - {% call statement('get_columns_in_query', fetch_result=True, auto_begin=False) -%} - select TOP 0 * from ( - {{ select_sql }} - ) as __dbt_sbq - where 0 = 1 - {{ query_label }} - {% endcall %} - - {{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }} -{% endmacro %} - -{% macro sqlserver__alter_column_type(relation, column_name, new_column_type) %} - - {%- set tmp_column = column_name + "__dbt_alter" -%} - {% set alter_column_type %} - alter {{ relation.type }} {{ relation }} add "{{ tmp_column }}" {{ new_column_type }}; - {%- endset %} - - {% set update_column %} - update {{ relation }} set "{{ tmp_column }}" = "{{ column_name }}"; - {%- endset %} - - {% set drop_column %} - alter {{ relation.type }} {{ relation }} drop column "{{ column_name }}"; - {%- endset %} - - {% set rename_column %} - exec sp_rename '{{ relation | replace('"', '') }}.{{ tmp_column }}', '{{ column_name }}', 'column' - {%- endset %} - - {% do run_query(alter_column_type) %} - {% do run_query(update_column) %} - {% do run_query(drop_column) %} - {% do run_query(rename_column) %} - -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/metadata.sql 
b/dbt/include/sqlserver/macros/adapter/metadata.sql deleted file mode 100644 index ac8981c9..00000000 --- a/dbt/include/sqlserver/macros/adapter/metadata.sql +++ /dev/null @@ -1,8 +0,0 @@ -{% macro apply_label() %} - {{ log (config.get('query_tag','dbt-sqlserver'))}} - {%- set query_label = config.get('query_tag','dbt-sqlserver') -%} - OPTION (LABEL = '{{query_label}}'); -{% endmacro %} - -{% macro default__information_schema_hints() %}{% endmacro %} -{% macro sqlserver__information_schema_hints() %}with (nolock){% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/relation.sql b/dbt/include/sqlserver/macros/adapter/relation.sql deleted file mode 100644 index 57defbd1..00000000 --- a/dbt/include/sqlserver/macros/adapter/relation.sql +++ /dev/null @@ -1,5 +0,0 @@ -{% macro sqlserver__truncate_relation(relation) %} - {% call statement('truncate_relation') -%} - truncate table {{ relation }} - {%- endcall %} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/schemas.sql b/dbt/include/sqlserver/macros/adapter/schemas.sql deleted file mode 100644 index 8317d6cb..00000000 --- a/dbt/include/sqlserver/macros/adapter/schemas.sql +++ /dev/null @@ -1,5 +0,0 @@ - -{% macro sqlserver__drop_schema_named(schema_name) %} - {% set schema_relation = api.Relation.create(schema=schema_name) %} - {{ adapter.drop_schema(schema_relation) }} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/apply_grants.sql b/dbt/include/sqlserver/macros/adapters/apply_grants.sql new file mode 100644 index 00000000..a9ccbd9c --- /dev/null +++ b/dbt/include/sqlserver/macros/adapters/apply_grants.sql @@ -0,0 +1,71 @@ +{% macro sqlserver__apply_grants(relation, grant_config, should_revoke=True) %} + {#-- If grant_config is {} or None, this is a no-op --#} + {% if grant_config %} + {% if should_revoke %} + {#-- We think previous grants may have carried over --#} + {#-- Show current grants and calculate diffs --#} + {% set current_grants_table = run_query(get_show_grant_sql(relation)) %} + {% set current_grants_dict = adapter.standardize_grants_dict(current_grants_table) %} + {% set needs_granting = diff_of_two_dicts(grant_config, current_grants_dict) %} + {% set needs_revoking = diff_of_two_dicts(current_grants_dict, grant_config) %} + {% if not (needs_granting or needs_revoking) %} + {{ log('On ' ~ relation ~': All grants are in place, no revocation or granting needed.')}} + {% endif %} + {% else %} + {#-- We don't think there's any chance of previous grants having carried over. --#} + {#-- Jump straight to granting what the user has configured. 
--#} + {% set needs_revoking = {} %} + {% set needs_granting = grant_config %} + {% endif %} + {% if needs_granting or needs_revoking %} + {% set revoke_statement_list = get_dcl_statement_list(relation, needs_revoking, get_revoke_sql) %} + + {% if config.get('auto_provision_aad_principals', False) %} + {% set provision_statement_list = get_dcl_statement_list(relation, needs_granting, get_provision_sql) %} + {% else %} + {% set provision_statement_list = [] %} + {% endif %} + + {% set grant_statement_list = get_dcl_statement_list(relation, needs_granting, get_grant_sql) %} + {% set dcl_statement_list = revoke_statement_list + provision_statement_list + grant_statement_list %} + {% if dcl_statement_list %} + {{ call_dcl_statements(dcl_statement_list) }} + {% endif %} + {% endif %} + {% endif %} +{% endmacro %} + +{% macro sqlserver__get_show_grant_sql(relation) %} + select + GRANTEE as grantee, + PRIVILEGE_TYPE as privilege_type + from INFORMATION_SCHEMA.TABLE_PRIVILEGES {{ information_schema_hints() }} + where TABLE_CATALOG = '{{ relation.database }}' + and TABLE_SCHEMA = '{{ relation.schema }}' + and TABLE_NAME = '{{ relation.identifier }}' +{% endmacro %} + +{%- macro sqlserver__get_grant_sql(relation, privilege, grantees) -%} + {%- set grantees_safe = [] -%} + {%- for grantee in grantees -%} + {%- set grantee_safe = "[" ~ grantee ~ "]" -%} + {%- do grantees_safe.append(grantee_safe) -%} + {%- endfor -%} + grant {{ privilege }} on {{ relation }} to {{ grantees_safe | join(', ') }} +{%- endmacro -%} + +{%- macro sqlserver__get_revoke_sql(relation, privilege, grantees) -%} + {%- set grantees_safe = [] -%} + {%- for grantee in grantees -%} + {%- set grantee_safe = "[" ~ grantee ~ "]" -%} + {%- do grantees_safe.append(grantee_safe) -%} + {%- endfor -%} + revoke {{ privilege }} on {{ relation }} from {{ grantees_safe | join(', ') }} +{%- endmacro -%} + +{% macro get_provision_sql(relation, privilege, grantees) %} + {% for grantee in grantees %} + if not exists(select name from sys.database_principals where name = '{{ grantee }}') + create user [{{ grantee }}] from external provider; + {% endfor %} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/catalog.sql b/dbt/include/sqlserver/macros/adapters/catalog.sql similarity index 100% rename from dbt/include/sqlserver/macros/adapter/catalog.sql rename to dbt/include/sqlserver/macros/adapters/catalog.sql diff --git a/dbt/include/sqlserver/macros/adapters/columns.sql b/dbt/include/sqlserver/macros/adapters/columns.sql new file mode 100644 index 00000000..f5efc6b4 --- /dev/null +++ b/dbt/include/sqlserver/macros/adapters/columns.sql @@ -0,0 +1,101 @@ +{% macro sqlserver__get_empty_subquery_sql(select_sql, select_sql_header=none) %} + {% if select_sql.strip().lower().startswith('with') %} + {{ select_sql }} + {% else -%} + select * from ( + {{ select_sql }} + ) dbt_sbq_tmp + where 1 = 0 + {%- endif -%} + +{% endmacro %} + +{% macro sqlserver__get_columns_in_query(select_sql) %} + {% set query_label = apply_label() %} + {% call statement('get_columns_in_query', fetch_result=True, auto_begin=False) -%} + select TOP 0 * from ( + {{ select_sql }} + ) as __dbt_sbq + where 0 = 1 + {{ query_label }} + {% endcall %} + + {{ return(load_result('get_columns_in_query').table.columns | map(attribute='name') | list) }} +{% endmacro %} + +{% macro sqlserver__alter_column_type(relation, column_name, new_column_type) %} + + {%- set tmp_column = column_name + "__dbt_alter" -%} + {% set alter_column_type %} + alter {{ relation.type }} {{ relation }} add 
"{{ tmp_column }}" {{ new_column_type }}; + {%- endset %} + + {% set update_column %} + update {{ relation }} set "{{ tmp_column }}" = "{{ column_name }}"; + {%- endset %} + + {% set drop_column %} + alter {{ relation.type }} {{ relation }} drop column "{{ column_name }}"; + {%- endset %} + + {% set rename_column %} + exec sp_rename '{{ relation | replace('"', '') }}.{{ tmp_column }}', '{{ column_name }}', 'column' + {%- endset %} + + {% do run_query(alter_column_type) %} + {% do run_query(update_column) %} + {% do run_query(drop_column) %} + {% do run_query(rename_column) %} + +{% endmacro %} + + +{% macro sqlserver__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} + {% call statement('add_drop_columns') -%} + {% if add_columns %} + alter {{ relation.type }} {{ relation }} + add {% for column in add_columns %}"{{ column.name }}" {{ column.data_type }}{{ ', ' if not loop.last }}{% endfor %}; + {% endif %} + + {% if remove_columns %} + alter {{ relation.type }} {{ relation }} + drop column {% for column in remove_columns %}"{{ column.name }}"{{ ',' if not loop.last }}{% endfor %}; + {% endif %} + {%- endcall -%} +{% endmacro %} + +{% macro sqlserver__get_columns_in_relation(relation) -%} + {% set query_label = apply_label() %} + {% call statement('get_columns_in_relation', fetch_result=True) %} + {{ get_use_database_sql(relation.database) }} + with mapping as ( + select + row_number() over (partition by object_name(c.object_id) order by c.column_id) as ordinal_position, + c.name collate database_default as column_name, + t.name as data_type, + case + when (t.name in ('nchar', 'nvarchar', 'sysname') and c.max_length <> -1) then c.max_length / 2 + else c.max_length + end as character_maximum_length, + c.precision as numeric_precision, + c.scale as numeric_scale + from sys.columns c {{ information_schema_hints() }} + inner join sys.types t {{ information_schema_hints() }} + on c.user_type_id = t.user_type_id + where c.object_id = object_id('{{ 'tempdb..' 
~ relation.include(database=false, schema=false) if '#' in relation.identifier else relation }}') + ) + + select + column_name, + data_type, + character_maximum_length, + numeric_precision, + numeric_scale + from mapping + order by ordinal_position + {{ query_label }} + + {% endcall %} + {% set table = load_result('get_columns_in_relation').table %} + {{ return(sql_convert_columns_in_relation(table)) }} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/indexes.sql b/dbt/include/sqlserver/macros/adapters/indexes.sql similarity index 71% rename from dbt/include/sqlserver/macros/adapter/indexes.sql rename to dbt/include/sqlserver/macros/adapters/indexes.sql index 33fa6cfe..2fd3733a 100644 --- a/dbt/include/sqlserver/macros/adapter/indexes.sql +++ b/dbt/include/sqlserver/macros/adapters/indexes.sql @@ -168,3 +168,71 @@ {% endif %} end {% endmacro %} + + +{% macro drop_fk_indexes_on_table(relation) -%} + {% call statement('find_references', fetch_result=true) %} + USE [{{ relation.database }}]; + SELECT obj.name AS FK_NAME, + sch.name AS [schema_name], + tab1.name AS [table], + col1.name AS [column], + tab2.name AS [referenced_table], + col2.name AS [referenced_column] + FROM sys.foreign_key_columns fkc + INNER JOIN sys.objects obj + ON obj.object_id = fkc.constraint_object_id + INNER JOIN sys.tables tab1 + ON tab1.object_id = fkc.parent_object_id + INNER JOIN sys.schemas sch + ON tab1.schema_id = sch.schema_id + INNER JOIN sys.columns col1 + ON col1.column_id = parent_column_id AND col1.object_id = tab1.object_id + INNER JOIN sys.tables tab2 + ON tab2.object_id = fkc.referenced_object_id + INNER JOIN sys.columns col2 + ON col2.column_id = referenced_column_id AND col2.object_id = tab2.object_id + WHERE sch.name = '{{ relation.schema }}' and tab2.name = '{{ relation.identifier }}' + {% endcall %} + {% set references = load_result('find_references')['data'] %} + {% for reference in references -%} + {% call statement('main') -%} + alter table [{{reference[1]}}].[{{reference[2]}}] drop constraint [{{reference[0]}}] + {%- endcall %} + {% endfor %} +{% endmacro %} + +{% macro sqlserver__list_nonclustered_rowstore_indexes(relation) -%} + {% call statement('list_nonclustered_rowstore_indexes', fetch_result=True) -%} + + SELECT i.name AS index_name + , i.name + '__dbt_backup' as index_new_name + , COL_NAME(ic.object_id,ic.column_id) AS column_name + FROM sys.indexes AS i + INNER JOIN sys.index_columns AS ic + ON i.object_id = ic.object_id AND i.index_id = ic.index_id and i.type <> 5 + WHERE i.object_id = OBJECT_ID('{{ relation.schema }}.{{ relation.identifier }}') + + UNION ALL + + SELECT obj.name AS index_name + , obj.name + '__dbt_backup' as index_new_name + , col1.name AS column_name + FROM sys.foreign_key_columns fkc + INNER JOIN sys.objects obj + ON obj.object_id = fkc.constraint_object_id + INNER JOIN sys.tables tab1 + ON tab1.object_id = fkc.parent_object_id + INNER JOIN sys.schemas sch + ON tab1.schema_id = sch.schema_id + INNER JOIN sys.columns col1 + ON col1.column_id = parent_column_id AND col1.object_id = tab1.object_id + INNER JOIN sys.tables tab2 + ON tab2.object_id = fkc.referenced_object_id + INNER JOIN sys.columns col2 + ON col2.column_id = referenced_column_id AND col2.object_id = tab2.object_id + WHERE sch.name = '{{ relation.schema }}' and tab1.name = '{{ relation.identifier }}' + + {% endcall %} + {{ return(load_result('list_nonclustered_rowstore_indexes').table) }} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/metadata.sql 
b/dbt/include/sqlserver/macros/adapters/metadata.sql new file mode 100644 index 00000000..4d550fcd --- /dev/null +++ b/dbt/include/sqlserver/macros/adapters/metadata.sql @@ -0,0 +1,112 @@ +{% macro apply_label() %} + {{ log (config.get('query_tag','dbt-sqlserver'))}} + {%- set query_label = config.get('query_tag','dbt-sqlserver') -%} + OPTION (LABEL = '{{query_label}}'); +{% endmacro %} + +{% macro default__information_schema_hints() %}{% endmacro %} +{% macro sqlserver__information_schema_hints() %}with (nolock){% endmacro %} + +{% macro information_schema_hints() %} + {{ return(adapter.dispatch('information_schema_hints')()) }} +{% endmacro %} + +{% macro sqlserver__information_schema_name(database) -%} + information_schema +{%- endmacro %} + +{% macro get_use_database_sql(database) %} + {{ return(adapter.dispatch('get_use_database_sql', 'dbt')(database)) }} +{% endmacro %} + +{%- macro sqlserver__get_use_database_sql(database) -%} + USE [{{database | replace('"', '')}}]; +{%- endmacro -%} + +{% macro sqlserver__list_schemas(database) %} + {% call statement('list_schemas', fetch_result=True, auto_begin=False) -%} + {{ get_use_database_sql(database) }} + select name as [schema] + from sys.schemas {{ information_schema_hints() }} {{ apply_label() }} + {% endcall %} + {{ return(load_result('list_schemas').table) }} +{% endmacro %} + +{% macro sqlserver__check_schema_exists(information_schema, schema) -%} + {% call statement('check_schema_exists', fetch_result=True, auto_begin=False) -%} + SELECT count(*) as schema_exist FROM sys.schemas WHERE name = '{{ schema }}' {{ apply_label() }} + {%- endcall %} + {{ return(load_result('check_schema_exists').table) }} +{% endmacro %} + +{% macro sqlserver__list_relations_without_caching(schema_relation) -%} + {% call statement('list_relations_without_caching', fetch_result=True) -%} + {{ get_use_database_sql(schema_relation.database) }} + with base as ( + select + DB_NAME() as [database], + t.name as [name], + SCHEMA_NAME(t.schema_id) as [schema], + 'table' as table_type + from sys.tables as t {{ information_schema_hints() }} + union all + select + DB_NAME() as [database], + v.name as [name], + SCHEMA_NAME(v.schema_id) as [schema], + 'view' as table_type + from sys.views as v {{ information_schema_hints() }} + ) + select * from base + where [schema] like '{{ schema_relation.schema }}' + {{ apply_label() }} + {% endcall %} + {{ return(load_result('list_relations_without_caching').table) }} +{% endmacro %} + +{% macro sqlserver__get_relation_without_caching(schema_relation) -%} + {% call statement('get_relation_without_caching', fetch_result=True) -%} + {{ get_use_database_sql(schema_relation.database) }} + with base as ( + select + DB_NAME() as [database], + t.name as [name], + SCHEMA_NAME(t.schema_id) as [schema], + 'table' as table_type + from sys.tables as t {{ information_schema_hints() }} + union all + select + DB_NAME() as [database], + v.name as [name], + SCHEMA_NAME(v.schema_id) as [schema], + 'view' as table_type + from sys.views as v {{ information_schema_hints() }} + ) + select * from base + where [schema] like '{{ schema_relation.schema }}' + and [name] like '{{ schema_relation.identifier }}' + {{ apply_label() }} + {% endcall %} + {{ return(load_result('get_relation_without_caching').table) }} +{% endmacro %} + +{% macro sqlserver__get_relation_last_modified(information_schema, relations) -%} + {%- call statement('last_modified', fetch_result=True) -%} + select + o.name as [identifier] + , s.name as [schema] + , o.modify_date as 
last_modified + , current_timestamp as snapshotted_at + from sys.objects o + inner join sys.schemas s on o.schema_id = s.schema_id and [type] = 'U' + where ( + {%- for relation in relations -%} + (upper(s.name) = upper('{{ relation.schema }}') and + upper(o.name) = upper('{{ relation.identifier }}')){%- if not loop.last %} or {% endif -%} + {%- endfor -%} + ) + {{ apply_label() }} + {%- endcall -%} + {{ return(load_result('last_modified')) }} + +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/relation.sql b/dbt/include/sqlserver/macros/adapters/relation.sql new file mode 100644 index 00000000..b6c4036d --- /dev/null +++ b/dbt/include/sqlserver/macros/adapters/relation.sql @@ -0,0 +1,55 @@ +{% macro sqlserver__make_temp_relation(base_relation, suffix='__dbt_temp') %} + {%- set temp_identifier = base_relation.identifier ~ suffix -%} + {%- set temp_relation = base_relation.incorporate( + path={"identifier": temp_identifier}) -%} + + {{ return(temp_relation) }} +{% endmacro %} + +{% macro sqlserver__get_drop_sql(relation) -%} + {% if relation.type == 'view' -%} + {% call statement('find_references', fetch_result=true) %} + {{ get_use_database_sql(relation.database) }} + select + sch.name as schema_name, + obj.name as view_name + from sys.sql_expression_dependencies refs + inner join sys.objects obj + on refs.referencing_id = obj.object_id + inner join sys.schemas sch + on obj.schema_id = sch.schema_id + where refs.referenced_database_name = '{{ relation.database }}' + and refs.referenced_schema_name = '{{ relation.schema }}' + and refs.referenced_entity_name = '{{ relation.identifier }}' + and obj.type = 'V' + {{ apply_label() }} + {% endcall %} + {% set references = load_result('find_references')['data'] %} + {% for reference in references -%} + -- dropping referenced view {{ reference[0] }}.{{ reference[1] }} + {% do adapter.drop_relation + (api.Relation.create( + identifier = reference[1], schema = reference[0], database = relation.database, type='view' + ))%} + {% endfor %} + {% elif relation.type == 'table'%} + {% set object_id_type = 'U' %} + {%- else -%} + {{ exceptions.raise_not_implemented('Invalid relation being dropped: ' ~ relation) }} + {% endif %} + {{ get_use_database_sql(relation.database) }} + EXEC('DROP {{ relation.type }} IF EXISTS {{ relation.include(database=False) }};'); +{% endmacro %} + +{% macro sqlserver__rename_relation(from_relation, to_relation) -%} + {% call statement('rename_relation') -%} + {{ get_use_database_sql(from_relation.database) }} + EXEC sp_rename '{{ from_relation.schema }}.{{ from_relation.identifier }}', '{{ to_relation.identifier }}' + {%- endcall %} +{% endmacro %} + +{% macro sqlserver__truncate_relation(relation) -%} + {% call statement('truncate_relation') -%} + truncate table {{ relation }} + {%- endcall %} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/schema.sql b/dbt/include/sqlserver/macros/adapters/schema.sql new file mode 100644 index 00000000..4fbef3e3 --- /dev/null +++ b/dbt/include/sqlserver/macros/adapters/schema.sql @@ -0,0 +1,42 @@ +{% macro sqlserver__create_schema(relation) -%} + {% call statement('create_schema') -%} + USE [{{ relation.database }}]; + IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') + BEGIN + EXEC('CREATE SCHEMA [{{ relation.schema }}]') + END + {% endcall %} +{% endmacro %} + +{% macro sqlserver__create_schema_with_authorization(relation, schema_authorization) -%} + {% call statement('create_schema') -%} + {{ 
get_use_database_sql(relation.database) }} + IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ relation.schema }}') + BEGIN + EXEC('CREATE SCHEMA [{{ relation.schema }}] AUTHORIZATION [{{ schema_authorization }}]') + END + {% endcall %} +{% endmacro %} + +{% macro sqlserver__drop_schema(relation) -%} + {%- set relations_in_schema = list_relations_without_caching(relation) %} + + {% for row in relations_in_schema %} + {%- set schema_relation = api.Relation.create(database=relation.database, + schema=relation.schema, + identifier=row[1], + type=row[3] + ) -%} + {% do adapter.drop_relation(schema_relation) %} + {%- endfor %} + + {% call statement('drop_schema') -%} + {{ get_use_database_sql(relation.database) }} + EXEC('DROP SCHEMA IF EXISTS {{ relation.schema }}') + {% endcall %} +{% endmacro %} + +{% macro sqlserver__drop_schema_named(schema_name) %} + {% set schema_relation = api.Relation.create(schema=schema_name) %} + {{ adapter.drop_schema(schema_relation) }} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapters/show.sql b/dbt/include/sqlserver/macros/adapters/show.sql new file mode 100644 index 00000000..14337a73 --- /dev/null +++ b/dbt/include/sqlserver/macros/adapters/show.sql @@ -0,0 +1,12 @@ +{% macro sqlserver__get_limit_sql(sql, limit) %} + {%- if limit == -1 or limit is none -%} + {{ sql }} + {#- Special processing if the last non-blank line starts with order by -#} + {%- elif sql.strip().splitlines()[-1].strip().lower().startswith('order by') -%} + {{ sql }} + offset 0 rows fetch first {{ limit }} rows only + {%- else -%} + {{ sql }} + order by (select null) offset 0 rows fetch first {{ limit }} rows only + {%- endif -%} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/adapter/validate_sql.sql b/dbt/include/sqlserver/macros/adapters/validate_sql.sql similarity index 100% rename from dbt/include/sqlserver/macros/adapter/validate_sql.sql rename to dbt/include/sqlserver/macros/adapters/validate_sql.sql diff --git a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql new file mode 100644 index 00000000..393d7020 --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental_strategies.sql @@ -0,0 +1,11 @@ +{% macro sqlserver__get_incremental_default_sql(arg_dict) %} + + {% if arg_dict["unique_key"] %} + -- Delete + Insert Strategy, calls get_delete_insert_merge_sql + {% do return(get_incremental_merge_sql(arg_dict)) %} + {% else %} + -- Incremental Append will insert data into target table. 
+ {% do return(get_incremental_append_sql(arg_dict)) %} + {% endif %} + +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql b/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql index 9d8cdc0f..7c325cd4 100644 --- a/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql +++ b/dbt/include/sqlserver/macros/materializations/models/incremental/merge.sql @@ -1,3 +1,57 @@ +{% macro sqlserver__get_merge_sql(target, source, unique_key, dest_columns, incremental_predicates=none) %} + {{ default__get_merge_sql(target, source, unique_key, dest_columns, incremental_predicates) }}; +{% endmacro %} + +{% macro sqlserver__get_insert_overwrite_merge_sql(target, source, dest_columns, predicates, include_sql_header) %} + {{ default__get_insert_overwrite_merge_sql(target, source, dest_columns, predicates, include_sql_header) }}; +{% endmacro %} + +{% macro sqlserver__get_delete_insert_merge_sql(target, source, unique_key, dest_columns, incremental_predicates=none) %} + + {% set query_label = apply_label() %} + {%- set dest_cols_csv = get_quoted_csv(dest_columns | map(attribute="name")) -%} + + {% if unique_key %} + {% if unique_key is sequence and unique_key is not string %} + delete from {{ target }} + where exists ( + select null + from {{ source }} + where + {% for key in unique_key %} + {{ source }}.{{ key }} = {{ target }}.{{ key }} + {{ "and " if not loop.last }} + {% endfor %} + ) + {% if incremental_predicates %} + {% for predicate in incremental_predicates %} + and {{ predicate }} + {% endfor %} + {% endif %} + {{ query_label }} + {% else %} + delete from {{ target }} + where ( + {{ unique_key }}) in ( + select ({{ unique_key }}) + from {{ source }} + ) + {%- if incremental_predicates %} + {% for predicate in incremental_predicates %} + and {{ predicate }} + {% endfor %} + {%- endif -%} + {{ query_label }} + {% endif %} + {% endif %} + + insert into {{ target }} ({{ dest_cols_csv }}) + ( + select {{ dest_cols_csv }} + from {{ source }} + ){{ query_label }} +{% endmacro %} + {% macro sqlserver__get_incremental_microbatch_sql(arg_dict) %} {%- set target = arg_dict["target_relation"] -%} {%- set source = arg_dict["temp_relation"] -%} diff --git a/dbt/include/sqlserver/macros/materializations/models/table/clone.sql b/dbt/include/sqlserver/macros/materializations/models/table/clone.sql new file mode 100644 index 00000000..a5981283 --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/models/table/clone.sql @@ -0,0 +1,4 @@ +{% macro sqlserver__create_or_replace_clone(target_relation, defer_relation) %} + CREATE TABLE {{target_relation}} + AS CLONE OF {{defer_relation}} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/table/columns_spec_ddl.sql b/dbt/include/sqlserver/macros/materializations/models/table/columns_spec_ddl.sql new file mode 100644 index 00000000..e545dbad --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/models/table/columns_spec_ddl.sql @@ -0,0 +1,30 @@ +{% macro build_columns_constraints(relation) %} + {{ return(adapter.dispatch('build_columns_constraints', 'dbt')(relation)) }} +{% endmacro %} + +{% macro sqlserver__build_columns_constraints(relation) %} + {# loop through user_provided_columns to create DDL with data types and constraints #} + {%- set raw_column_constraints = adapter.render_raw_columns_constraints(raw_columns=model['columns']) -%} + ( + {% for c in raw_column_constraints -%} + {{ c }}{{ "," if not loop.last }} + {% 
endfor %} + ) +{% endmacro %} + +{% macro build_model_constraints(relation) %} + {{ return(adapter.dispatch('build_model_constraints', 'dbt')(relation)) }} +{% endmacro %} + +{% macro sqlserver__build_model_constraints(relation) %} + {# loop through model-level constraints and apply each one with an ALTER TABLE statement #} + {%- set raw_model_constraints = adapter.render_raw_model_constraints(raw_constraints=model['constraints']) -%} + {% for c in raw_model_constraints -%} + {% set alter_table_script %} + alter table {{ relation.include(database=False) }} {{c}}; + {%endset%} + {% call statement('alter_table_add_constraint') -%} + {{alter_table_script}} + {%- endcall %} + {% endfor -%} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/unit_test/unit_test_create_table_as.sql b/dbt/include/sqlserver/macros/materializations/models/unit_test/unit_test_create_table_as.sql new file mode 100644 index 00000000..2919d470 --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/models/unit_test/unit_test_create_table_as.sql @@ -0,0 +1,58 @@ +{% macro check_for_nested_cte(sql) %} + {% if execute %} {# Ensure this runs only at execution time #} + {% set cleaned_sql = sql | lower | replace("\n", " ") %} {# Convert to lowercase and replace newlines with spaces #} + {% set cte_count = cleaned_sql.count("with ") %} {# Count occurrences of "WITH " #} + {% if cte_count > 1 %} + {{ return(True) }} + {% else %} + {{ return(False) }} {# No nested CTEs found #} + {% endif %} + {% else %} + {{ return(False) }} {# Return False during parsing #} + {% endif %} +{% endmacro %} + +{% macro sqlserver__unit_test_create_table_as(temporary, relation, sql) -%} + {% set query_label = apply_label() %} + {% set contract_config = config.get('contract') %} + {% set is_nested_cte = check_for_nested_cte(sql) %} + + {% if is_nested_cte %} + {{ exceptions.warn( + "Nested CTE warning: nested CTEs are not supported in CTAS statements. Two-level nested CTEs currently work only because of a quirk in the implementation; expect this to be fixed in a future release." + ) }} + {% endif %} + + {% if is_nested_cte and contract_config.enforced %} + + {{ exceptions.raise_compiler_error( + "Unit test materialization error: the contract is enforced and the model contains a nested CTE, so the unit test cannot be materialized. Please refactor the model or disable contract enforcement and try again."
+ ) }} + + {%- elif not is_nested_cte and contract_config.enforced %} + + CREATE TABLE {{relation}} + {{ build_columns_constraints(relation) }} + {{ get_assert_columns_equivalent(sql) }} + + {% set listColumns %} + {% for column in model['columns'] %} + {{ "["~column~"]" }}{{ ", " if not loop.last }} + {% endfor %} + {%endset%} + + {% set tmp_vw_relation = relation.incorporate(path={"identifier": relation.identifier ~ '__dbt_tmp_vw'}, type='view')-%} + {% do adapter.drop_relation(tmp_vw_relation) %} + {{ get_create_view_as_sql(tmp_vw_relation, sql) }} + + INSERT INTO {{relation}} ({{listColumns}}) + SELECT {{listColumns}} FROM {{tmp_vw_relation}} {{ query_label }} + + {%- else %} + + {%- set query_label_option = query_label.replace("'", "''") -%} + {%- set sql_with_quotes = sql.replace("'", "''") -%} + EXEC('CREATE TABLE {{relation}} AS {{sql_with_quotes}} {{ query_label_option }}'); + + {% endif %} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/view/create_view_as.sql b/dbt/include/sqlserver/macros/materializations/models/view/create_view_as.sql new file mode 100644 index 00000000..9066cb21 --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/models/view/create_view_as.sql @@ -0,0 +1,12 @@ +{% macro sqlserver__create_view_exec(relation, sql) -%} + + {%- set temp_view_sql = sql.replace("'", "''") -%} + {{ get_use_database_sql(relation.database) }} + {% set contract_config = config.get('contract') %} + {% if contract_config.enforced %} + {{ get_assert_columns_equivalent(sql) }} + {%- endif %} + + EXEC('create view {{ relation.include(database=False) }} as {{ temp_view_sql }};'); + +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshot/helpers.sql b/dbt/include/sqlserver/macros/materializations/snapshot/helpers.sql deleted file mode 100644 index 3317a9f3..00000000 --- a/dbt/include/sqlserver/macros/materializations/snapshot/helpers.sql +++ /dev/null @@ -1,35 +0,0 @@ -{% macro sqlserver__create_columns(relation, columns) %} - {% set column_list %} - {% for column_entry in columns %} - {{column_entry.name}} {{column_entry.data_type}}{{ ", " if not loop.last }} - {% endfor %} - {% endset %} - - {% set alter_sql %} - ALTER TABLE {{ relation }} - ADD {{ column_list }} - {% endset %} - - {% set results = run_query(alter_sql) %} - -{% endmacro %} - -{% macro build_snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} - {% set temp_relation = make_temp_relation(target_relation) %} - {{ adapter.drop_relation(temp_relation) }} - - {% set select = snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} - - {% set tmp_tble_vw_relation = temp_relation.incorporate(path={"identifier": temp_relation.identifier ~ '__dbt_tmp_vw'}, type='view')-%} - -- Dropping temp view relation if it exists - {{ adapter.drop_relation(tmp_tble_vw_relation) }} - - {% call statement('build_snapshot_staging_relation') %} - {{ get_create_table_as_sql(True, temp_relation, select) }} - {% endcall %} - - -- Dropping temp view relation if it exists - {{ adapter.drop_relation(tmp_tble_vw_relation) }} - - {% do return(temp_relation) %} -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshot/snapshot_merge.sql b/dbt/include/sqlserver/macros/materializations/snapshot/snapshot_merge.sql deleted file mode 100644 index caae7740..00000000 --- a/dbt/include/sqlserver/macros/materializations/snapshot/snapshot_merge.sql +++ /dev/null @@ -1,19 +0,0 @@ -{% macro sqlserver__snapshot_merge_sql(target, 
source, insert_cols) -%} - {%- set insert_cols_csv = insert_cols | join(', ') -%} - - merge into {{ target.render() }} as DBT_INTERNAL_DEST - using {{ source }} as DBT_INTERNAL_SOURCE - on DBT_INTERNAL_SOURCE.dbt_scd_id = DBT_INTERNAL_DEST.dbt_scd_id - - when matched - and DBT_INTERNAL_DEST.dbt_valid_to is null - and DBT_INTERNAL_SOURCE.dbt_change_type in ('update', 'delete') - then update - set dbt_valid_to = DBT_INTERNAL_SOURCE.dbt_valid_to - - when not matched - and DBT_INTERNAL_SOURCE.dbt_change_type = 'insert' - then insert ({{ insert_cols_csv }}) - values ({{ insert_cols_csv }}) - ; -{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/helpers.sql b/dbt/include/sqlserver/macros/materializations/snapshots/helpers.sql new file mode 100644 index 00000000..0f1e908b --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/snapshots/helpers.sql @@ -0,0 +1,189 @@ +{% macro sqlserver__create_columns(relation, columns) %} + {% set column_list %} + {% for column_entry in columns %} + {{column_entry.name}} {{column_entry.data_type}}{{ ", " if not loop.last }} + {% endfor %} + {% endset %} + + {% set alter_sql %} + ALTER TABLE {{ relation }} + ADD {{ column_list }} + {% endset %} + + {% set results = run_query(alter_sql) %} + +{% endmacro %} + +{% macro build_snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} + {% set temp_relation = make_temp_relation(target_relation) %} + {{ adapter.drop_relation(temp_relation) }} + + {% set select = snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} + + {% set tmp_tble_vw_relation = temp_relation.incorporate(path={"identifier": temp_relation.identifier ~ '__dbt_tmp_vw'}, type='view')-%} + -- Dropping temp view relation if it exists + {{ adapter.drop_relation(tmp_tble_vw_relation) }} + + {% call statement('build_snapshot_staging_relation') %} + {{ get_create_table_as_sql(True, temp_relation, select) }} + {% endcall %} + + -- Dropping temp view relation if it exists + {{ adapter.drop_relation(tmp_tble_vw_relation) }} + + {% do return(temp_relation) %} +{% endmacro %} + + +{% macro sqlserver__post_snapshot(staging_relation) %} + -- Clean up the snapshot temp table + {% do drop_relation_if_exists(staging_relation) %} +{% endmacro %} + +{% macro sqlserver__get_true_sql() %} + {{ return('1=1') }} +{% endmacro %} + +{% macro sqlserver__build_snapshot_table(strategy, relation) %} + {% set columns = config.get('snapshot_table_column_names') or get_snapshot_table_column_names() %} + select *, + {{ strategy.scd_id }} as {{ columns.dbt_scd_id }}, + {{ strategy.updated_at }} as {{ columns.dbt_updated_at }}, + {{ strategy.updated_at }} as {{ columns.dbt_valid_from }}, + {{ get_dbt_valid_to_current(strategy, columns) }} + {%- if strategy.hard_deletes == 'new_record' -%} + , 'False' as {{ columns.dbt_is_deleted }} + {% endif -%} + from ( + select * from {{ relation }} + ) sbq + +{% endmacro %} + +{% macro sqlserver__snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) -%} + + {% set columns = config.get('snapshot_table_column_names') or get_snapshot_table_column_names() %} + + with snapshot_query as ( + select * from {{ temp_snapshot_relation }} + ), + snapshotted_data as ( + select *, + {{ unique_key_fields(strategy.unique_key) }} + from {{ target_relation }} + where + {% if config.get('dbt_valid_to_current') %} + {# Check for either dbt_valid_to_current OR null, in order to correctly update records with nulls #} + ( {{ columns.dbt_valid_to }} = {{ 
config.get('dbt_valid_to_current') }} or {{ columns.dbt_valid_to }} is null) + {% else %} + {{ columns.dbt_valid_to }} is null + {% endif %} + {%- if strategy.hard_deletes == 'new_record' -%} + and {{ columns.dbt_is_deleted }} = 'False' + {% endif -%} + ), + insertions_source_data as ( + select *, + {{ unique_key_fields(strategy.unique_key) }}, + {{ strategy.updated_at }} as {{ columns.dbt_updated_at }}, + {{ strategy.updated_at }} as {{ columns.dbt_valid_from }}, + {{ get_dbt_valid_to_current(strategy, columns) }}, + {{ strategy.scd_id }} as {{ columns.dbt_scd_id }} + from snapshot_query + ), + updates_source_data as ( + select *, + {{ unique_key_fields(strategy.unique_key) }}, + {{ strategy.updated_at }} as {{ columns.dbt_updated_at }}, + {{ strategy.updated_at }} as {{ columns.dbt_valid_from }}, + {{ strategy.updated_at }} as {{ columns.dbt_valid_to }} + from snapshot_query + ), + {%- if strategy.hard_deletes == 'invalidate' or strategy.hard_deletes == 'new_record' %} + deletes_source_data as ( + select *, {{ unique_key_fields(strategy.unique_key) }} + from snapshot_query + ), + {% endif %} + insertions as ( + select 'insert' as dbt_change_type, source_data.* + {%- if strategy.hard_deletes == 'new_record' -%} + ,'False' as {{ columns.dbt_is_deleted }} + {%- endif %} + from insertions_source_data as source_data + left outer join snapshotted_data + on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }} + where {{ unique_key_is_null(strategy.unique_key, "snapshotted_data") }} + or ({{ unique_key_is_not_null(strategy.unique_key, "snapshotted_data") }} and ({{ strategy.row_changed }})) + ), + updates as ( + select 'update' as dbt_change_type, source_data.*, + snapshotted_data.{{ columns.dbt_scd_id }} + {%- if strategy.hard_deletes == 'new_record' -%} + , snapshotted_data.{{ columns.dbt_is_deleted }} + {%- endif %} + from updates_source_data as source_data + join snapshotted_data + on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }} + where ({{ strategy.row_changed }}) + ) + {%- if strategy.hard_deletes == 'invalidate' or strategy.hard_deletes == 'new_record' %} + , + deletes as ( + select 'delete' as dbt_change_type, + source_data.*, + {{ snapshot_get_time() }} as {{ columns.dbt_valid_from }}, + {{ snapshot_get_time() }} as {{ columns.dbt_updated_at }}, + {{ snapshot_get_time() }} as {{ columns.dbt_valid_to }}, + snapshotted_data.{{ columns.dbt_scd_id }} + {%- if strategy.hard_deletes == 'new_record' -%} + , snapshotted_data.{{ columns.dbt_is_deleted }} + {%- endif %} + from snapshotted_data + left join deletes_source_data as source_data + on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }} + where {{ unique_key_is_null(strategy.unique_key, "source_data") }} + ) + {%- endif %} + {%- if strategy.hard_deletes == 'new_record' %} + {%set source_query = "select * from "~temp_snapshot_relation%} + {% set source_sql_cols = get_column_schema_from_query(source_query) %} + , + deletion_records as ( + + select + 'insert' as dbt_change_type, + {%- for col in source_sql_cols -%} + snapshotted_data.{{ adapter.quote(col.column) }}, + {% endfor -%} + {%- if strategy.unique_key | is_list -%} + {%- for key in strategy.unique_key -%} + snapshotted_data.{{ key }} as dbt_unique_key_{{ loop.index }}, + {% endfor -%} + {%- else -%} + snapshotted_data.dbt_unique_key as dbt_unique_key, + {% endif -%} + {{ snapshot_get_time() }} as {{ columns.dbt_valid_from }}, + {{ snapshot_get_time() }} as {{ columns.dbt_updated_at }}, + 
snapshotted_data.{{ columns.dbt_valid_to }} as {{ columns.dbt_valid_to }}, + snapshotted_data.{{ columns.dbt_scd_id }}, + 'True' as {{ columns.dbt_is_deleted }} + from snapshotted_data + left join deletes_source_data as source_data + on {{ unique_key_join_on(strategy.unique_key, "snapshotted_data", "source_data") }} + where {{ unique_key_is_null(strategy.unique_key, "source_data") }} + ) + {%- endif %} + select * from insertions + union all + select * from updates + {%- if strategy.hard_deletes == 'invalidate' or strategy.hard_deletes == 'new_record' %} + union all + select * from deletes + {%- endif %} + {%- if strategy.hard_deletes == 'new_record' %} + union all + select * from deletion_records + {%- endif %} + +{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshot/snapshot.sql b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql similarity index 63% rename from dbt/include/sqlserver/macros/materializations/snapshot/snapshot.sql rename to dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql index 5c9ddff3..23738453 100644 --- a/dbt/include/sqlserver/macros/materializations/snapshot/snapshot.sql +++ b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql @@ -1,8 +1,7 @@ {% materialization snapshot, adapter='sqlserver' %} - {%- set config = model['config'] -%} + {%- set config = model['config'] -%} {%- set target_table = model.get('alias', model.get('name')) -%} - {%- set strategy_name = config.get('strategy') -%} {%- set unique_key = config.get('unique_key') %} -- grab current tables grants config for comparision later on @@ -18,9 +17,7 @@ {% do exceptions.relation_wrong_type(target_relation, 'table') %} {%- endif -%} - {{ run_hooks(pre_hooks, inside_transaction=False) }} - {{ run_hooks(pre_hooks, inside_transaction=True) }} {% set strategy_macro = strategy_dispatch(strategy_name) %} @@ -30,69 +27,73 @@ database=model.database, schema=model.schema, identifier=target_table+"_snapshot_staging_temp_view", - type='view') - -%} - - {% set temp_snapshot_relation_sql = model['compiled_code'].replace("'", "''") %} - {% call statement('create temp_snapshot_relation') %} - USE [{{ model.database}}]; - EXEC('DROP VIEW IF EXISTS {{ temp_snapshot_relation.include(database=False) }};'); - EXEC('create view {{ temp_snapshot_relation.include(database=False) }} as {{ temp_snapshot_relation_sql }};'); - {% endcall %} + type='view') -%} + + -- Create a temporary view in case the user's SQL contains a CTE + {% set temp_snapshot_relation_sql = model['compiled_code'] %} + {{ adapter.drop_relation(temp_snapshot_relation) }} + + {% call statement('create temp_snapshot_relation') -%} + {{ get_create_view_as_sql(temp_snapshot_relation, temp_snapshot_relation_sql) }} + {%- endcall %} {% if not target_relation_exists %} {% set build_sql = build_snapshot_table(strategy, temp_snapshot_relation) %} - {% set final_sql = create_table_as(False, target_relation, build_sql) %} + {% set build_or_select_sql = build_sql %} - {% else %} + -- name a temp relation + {% set tmp_relation_view = target_relation.incorporate(path={"identifier": target_relation.identifier ~ '__dbt_tmp_vw'}, type='view')-%} + -- the SQL Server adapter uses a temp relation because CTEs are not supported inside CTAS or INSERT statements + -- drop temp relation if exists + {{ adapter.drop_relation(tmp_relation_view) }} + {% set final_sql = get_create_table_as_sql(False, target_relation, build_sql) %} + {{ adapter.drop_relation(tmp_relation_view) }} - {{ 
adapter.valid_snapshot_target(target_relation) }} + {% else %} + {% set columns = config.get("snapshot_meta_column_names") or get_snapshot_table_column_names() %} + {{ adapter.valid_snapshot_target(target_relation, columns) }} + {% set build_or_select_sql = snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} {% set staging_table = build_snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} - -- this may no-op if the database does not require column expansion {% do adapter.expand_target_column_types(from_relation=staging_table, to_relation=target_relation) %} + {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} + {% if unique_key | is_list %} + {% for key in strategy.unique_key %} + {{ remove_columns.append('dbt_unique_key_' + loop.index|string) }} + {{ remove_columns.append('DBT_UNIQUE_KEY_' + loop.index|string) }} + {% endfor %} + {% endif %} {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} - {% if missing_columns|length > 0 %} + {% if missing_columns|length > 0 %} {{log("Missing columns length is: "~ missing_columns|length)}} {% do create_columns(target_relation, missing_columns) %} {% endif %} - {% set source_columns = adapter.get_columns_in_relation(staging_table) - | rejectattr('name', 'equalto', 'dbt_change_type') - | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') - | rejectattr('name', 'equalto', 'dbt_unique_key') - | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') + | rejectattr('name', 'in', remove_columns) | list %} - {% set quoted_source_columns = [] %} {% for column in source_columns %} {% do quoted_source_columns.append(adapter.quote(column.name)) %} {% endfor %} - {% set final_sql = snapshot_merge_sql( target = target_relation, source = staging_table, insert_cols = quoted_source_columns ) %} - {% endif %} - + {{ check_time_data_types(build_or_select_sql) }} {% call statement('main') %} {{ final_sql }} {% endcall %} {{ adapter.drop_relation(temp_snapshot_relation) }} - {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode=False) %} {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke) %} @@ -103,7 +104,6 @@ {% endif %} {{ run_hooks(post_hooks, inside_transaction=True) }} - {{ adapter.commit() }} {% if staging_table is defined %} @@ -111,7 +111,6 @@ {% endif %} {{ run_hooks(post_hooks, inside_transaction=False) }} - {{ return({'relations': [target_relation]}) }} {% endmaterialization %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot_merge.sql b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot_merge.sql new file mode 100644 index 00000000..789fbea3 --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot_merge.sql @@ -0,0 +1,30 @@ +{% macro sqlserver__snapshot_merge_sql(target, source, insert_cols) %} + + {%- set insert_cols_csv = insert_cols | join(', ') -%} + {%- set columns = config.get("snapshot_table_column_names") or get_snapshot_table_column_names() -%} + {%- set target_table = target.include(database=False) -%} + {%- set source_table = source.include(database=False) -%} + {% set target_columns_list = [] %} + {% for column in insert_cols %} + {% set 
target_columns_list = target_columns_list.append("DBT_INTERNAL_SOURCE."+column) %} + {% endfor %} + {%- set target_columns = target_columns_list | join(', ') -%} + + update DBT_INTERNAL_DEST + set {{ columns.dbt_valid_to }} = DBT_INTERNAL_SOURCE.{{ columns.dbt_valid_to }} + from {{ target_table }} as DBT_INTERNAL_DEST + inner join {{ source_table }} as DBT_INTERNAL_SOURCE + on DBT_INTERNAL_SOURCE.{{ columns.dbt_scd_id }} = DBT_INTERNAL_DEST.{{ columns.dbt_scd_id }} + where DBT_INTERNAL_SOURCE.dbt_change_type in ('update', 'delete') + {% if config.get("dbt_valid_to_current") %} + and (DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} = {{ config.get('dbt_valid_to_current') }} or DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} is null) + {% else %} + and DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} is null + {% endif %} + {{ apply_label() }} + + insert into {{ target_table }} ({{ insert_cols_csv }}) + select {{target_columns}} from {{ source_table }} as DBT_INTERNAL_SOURCE + where DBT_INTERNAL_SOURCE.dbt_change_type = 'insert' + {{ apply_label() }} +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/strategies.sql b/dbt/include/sqlserver/macros/materializations/snapshots/strategies.sql new file mode 100644 index 00000000..6a316c6f --- /dev/null +++ b/dbt/include/sqlserver/macros/materializations/snapshots/strategies.sql @@ -0,0 +1,5 @@ +{% macro sqlserver__snapshot_hash_arguments(args) %} + CONVERT(VARCHAR(32), HashBytes('MD5', {% for arg in args %} + coalesce(cast({{ arg }} as varchar(8000)), '') {% if not loop.last %} + '|' + {% endif %} + {% endfor %}), 2) +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/unit_tests.sql b/dbt/include/sqlserver/macros/materializations/unit_tests.sql index 30a2a06b..c8b2675b 100644 --- a/dbt/include/sqlserver/macros/materializations/unit_tests.sql +++ b/dbt/include/sqlserver/macros/materializations/unit_tests.sql @@ -1,47 +1,47 @@ {% macro sqlserver__get_unit_test_sql(main_sql, expected_fixture_sql, expected_column_names) -%} -USE [{{ target.database }}]; -IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ target.schema }}') -BEGIN -EXEC('CREATE SCHEMA [{{ target.schema }}]') -END + USE [{{ target.database }}]; + IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{{ target.schema }}') + BEGIN + EXEC('CREATE SCHEMA [{{ target.schema }}]') + END -{% set test_view %} - [{{ target.schema }}].[testview_{{ local_md5(main_sql) }}_{{ range(1300, 19000) | random }}] -{% endset %} -{% set test_sql = main_sql.replace("'", "''")%} -EXEC('create view {{test_view}} as {{ test_sql }};') + {% set test_view %} + [{{ target.schema }}].[testview_{{ local_md5(main_sql) }}_{{ range(1300, 19000) | random }}] + {% endset %} + {% set test_sql = main_sql.replace("'", "''")%} + EXEC('create view {{test_view}} as {{ test_sql }};') -{% set expected_view %} - [{{ target.schema }}].[expectedview_{{ local_md5(expected_fixture_sql) }}_{{ range(1300, 19000) | random }}] -{% endset %} -{% set expected_sql = expected_fixture_sql.replace("'", "''")%} -EXEC('create view {{expected_view}} as {{ expected_sql }};') + {% set expected_view %} + [{{ target.schema }}].[expectedview_{{ local_md5(expected_fixture_sql) }}_{{ range(1300, 19000) | random }}] + {% endset %} + {% set expected_sql = expected_fixture_sql.replace("'", "''")%} + EXEC('create view {{expected_view}} as {{ expected_sql }};') --- Build actual result given inputs -{% set unittest_sql %} -with dbt_internal_unit_test_actual as ( - select - {% for expected_column_name in 
expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }} - from - {{ test_view }} -), --- Build expected result -dbt_internal_unit_test_expected as ( - select - {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }} - from - {{ expected_view }} -) --- Union actual and expected results -select * from dbt_internal_unit_test_actual -union all -select * from dbt_internal_unit_test_expected -{% endset %} + -- Build actual result given inputs + {% set unittest_sql %} + with dbt_internal_unit_test_actual as ( + select + {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%},{% endif %}{%- endfor -%}, {{ dbt.string_literal("actual") }} as {{ adapter.quote("actual_or_expected") }} + from + {{ test_view }} + ), + -- Build expected result + dbt_internal_unit_test_expected as ( + select + {% for expected_column_name in expected_column_names %}{{expected_column_name}}{% if not loop.last -%}, {% endif %}{%- endfor -%}, {{ dbt.string_literal("expected") }} as {{ adapter.quote("actual_or_expected") }} + from + {{ expected_view }} + ) + -- Union actual and expected results + select * from dbt_internal_unit_test_actual + union all + select * from dbt_internal_unit_test_expected + {% endset %} -EXEC('{{- escape_single_quotes(unittest_sql) -}}') + EXEC('{{- escape_single_quotes(unittest_sql) -}}') -EXEC('drop view {{test_view}};') -EXEC('drop view {{expected_view}};') + EXEC('drop view {{test_view}};') + EXEC('drop view {{expected_view}};') {%- endmacro %} diff --git a/dbt/include/sqlserver/macros/readme.md b/dbt/include/sqlserver/macros/readme.md deleted file mode 100644 index 26fc23c2..00000000 --- a/dbt/include/sqlserver/macros/readme.md +++ /dev/null @@ -1,54 +0,0 @@ -# Alterations from Fabric - -## `materialization incremental` - -This is reset to the original logic from the global project. - -## `materialization view` - -This is reset to the original logic from the global project - -## `materialization table` - -This is resets to the original logic from the global project - -## `sqlserver__create_columns` - -SQLServer supports ALTER; this updates the logic to apply alter instead of the drop/recreate - -## `sqlserver__alter_column_type` - -SQLServer supports ALTER; this updates the logic to apply alter instead of the drop/recreate - - -## `sqlserver__can_clone_table` - -SQLServer cannot clone, so this just returns False - -## `sqlserver__create_table_as` - -Logic is slightly re-written from original. -There is an underlying issue with the structure in that its embedding in EXEC calls. - -This creates an issue where temporary tables cannot be used, as they dont exist within the context of the EXEC call. - -One work around might be to issue the create table from a `{{ run_query }}` statement in order to have it accessible outside the exec context. - -Additionally the expected {% do adapter.drop_relation(tmp_relation) %} does not fire. Possible cache issue? -Resolved by calling `DROP VIEW IF EXISTS` on the relation - -## `sqlserver__create_view_as` - -Updated to remove `create_view_as_exec` call. - -## `listagg` - -DBT expects a limit function, but the sqlserver syntax does not support it. 
Fabric also does not implement this properly - -## `sqlserver__snapshot_merge_sql` - -Restores logic to the merge statement logic like the dbt core. Merge will probably be slower then the existing logic - -## unit tests - -To accomidate the nested CTE situation, we create a temp view for the actual/expected and use those both in the test. diff --git a/dbt/include/sqlserver/macros/utils/any_value.sql b/dbt/include/sqlserver/macros/utils/any_value.sql new file mode 100644 index 00000000..6dcf8ec2 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/any_value.sql @@ -0,0 +1,5 @@ +{% macro sqlserver__any_value(expression) -%} + + min({{ expression }}) + +{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/array_construct.sql b/dbt/include/sqlserver/macros/utils/array_construct.sql new file mode 100644 index 00000000..5088c9ac --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/array_construct.sql @@ -0,0 +1,3 @@ +{% macro sqlserver__array_construct(inputs, data_type) -%} + JSON_ARRAY({{ inputs|join(' , ') }}) +{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql b/dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql new file mode 100644 index 00000000..9771afbf --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/cast_bool_to_text.sql @@ -0,0 +1,7 @@ +{% macro sqlserver__cast_bool_to_text(field) %} + case {{ field }} + when 1 then 'true' + when 0 then 'false' + else null + end +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/concat.sql b/dbt/include/sqlserver/macros/utils/concat.sql new file mode 100644 index 00000000..1b7c1755 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/concat.sql @@ -0,0 +1,7 @@ +{% macro sqlserver__concat(fields) -%} + {%- if fields|length < 2 -%} + {{ fields[0] }} + {%- else -%} + concat({{ fields|join(', ') }}) + {%- endif -%} +{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/date_trunc.sql b/dbt/include/sqlserver/macros/utils/date_trunc.sql new file mode 100644 index 00000000..85b4ce32 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/date_trunc.sql @@ -0,0 +1,3 @@ +{% macro sqlserver__date_trunc(datepart, date) %} + CAST(DATEADD({{datepart}}, DATEDIFF({{datepart}}, 0, {{date}}), 0) AS DATE) +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/dateadd.sql b/dbt/include/sqlserver/macros/utils/dateadd.sql new file mode 100644 index 00000000..f3b24fa6 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/dateadd.sql @@ -0,0 +1,9 @@ +{% macro sqlserver__dateadd(datepart, interval, from_date_or_timestamp) %} + + dateadd( + {{ datepart }}, + {{ interval }}, + cast({{ from_date_or_timestamp }} as datetime2(6)) + ) + +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/get_tables_by_pattern.sql b/dbt/include/sqlserver/macros/utils/get_tables_by_pattern.sql new file mode 100644 index 00000000..75d6b500 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/get_tables_by_pattern.sql @@ -0,0 +1,12 @@ +{% macro sqlserver__get_tables_by_pattern_sql(schema_pattern, table_pattern, exclude='', database=target.database) %} + + select distinct + table_schema as {{ adapter.quote('table_schema') }}, + table_name as {{ adapter.quote('table_name') }}, + {{ dbt_utils.get_table_types_sql() }} + from {{ database }}.INFORMATION_SCHEMA.TABLES + where table_schema like '{{ schema_pattern }}' + and table_name like '{{ table_pattern }}' + and table_name not like '{{ exclude }}' + +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/hash.sql 
b/dbt/include/sqlserver/macros/utils/hash.sql new file mode 100644 index 00000000..d965f81f --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/hash.sql @@ -0,0 +1,3 @@ +{% macro sqlserver__hash(field) %} + lower(convert(varchar(50), hashbytes('md5', coalesce(convert(varchar(8000), {{field}}), '')), 2)) +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/last_day.sql b/dbt/include/sqlserver/macros/utils/last_day.sql new file mode 100644 index 00000000..c523d944 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/last_day.sql @@ -0,0 +1,13 @@ +{% macro sqlserver__last_day(date, datepart) -%} + + {%- if datepart == 'quarter' -%} + CAST(DATEADD(QUARTER, DATEDIFF(QUARTER, 0, {{ date }}) + 1, -1) AS DATE) + {%- elif datepart == 'month' -%} + EOMONTH ( {{ date }}) + {%- elif datepart == 'year' -%} + CAST(DATEADD(YEAR, DATEDIFF(year, 0, {{ date }}) + 1, -1) AS DATE) + {%- else -%} + {{dbt_utils.default_last_day(date, datepart)}} + {%- endif -%} + +{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/length.sql b/dbt/include/sqlserver/macros/utils/length.sql new file mode 100644 index 00000000..ee9431ac --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/length.sql @@ -0,0 +1,5 @@ +{% macro sqlserver__length(expression) %} + + len( {{ expression }} ) + +{%- endmacro -%} diff --git a/dbt/include/sqlserver/macros/utils/listagg.sql b/dbt/include/sqlserver/macros/utils/listagg.sql new file mode 100644 index 00000000..4d6ab215 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/listagg.sql @@ -0,0 +1,8 @@ +{% macro sqlserver__listagg(measure, delimiter_text, order_by_clause, limit_num) -%} + + string_agg({{ measure }}, {{ delimiter_text }}) + {%- if order_by_clause != None %} + within group ({{ order_by_clause }}) + {%- endif %} + +{%- endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/position.sql b/dbt/include/sqlserver/macros/utils/position.sql new file mode 100644 index 00000000..bd3f6577 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/position.sql @@ -0,0 +1,8 @@ +{% macro sqlserver__position(substring_text, string_text) %} + + CHARINDEX( + {{ substring_text }}, + {{ string_text }} + ) + +{%- endmacro -%} diff --git a/dbt/include/sqlserver/macros/utils/safe_cast.sql b/dbt/include/sqlserver/macros/utils/safe_cast.sql new file mode 100644 index 00000000..4ae065a7 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/safe_cast.sql @@ -0,0 +1,3 @@ +{% macro sqlserver__safe_cast(field, type) %} + try_cast({{field}} as {{type}}) +{% endmacro %} diff --git a/dbt/include/sqlserver/macros/utils/timestamps.sql b/dbt/include/sqlserver/macros/utils/timestamps.sql new file mode 100644 index 00000000..31795764 --- /dev/null +++ b/dbt/include/sqlserver/macros/utils/timestamps.sql @@ -0,0 +1,8 @@ +{% macro sqlserver__current_timestamp() -%} + CAST(SYSDATETIME() AS DATETIME2(6)) +{%- endmacro %} + +{% macro sqlserver__snapshot_string_as_time(timestamp) -%} + {%- set result = "CONVERT(DATETIME2(6), '" ~ timestamp ~ "')" -%} + {{ return(result) }} +{%- endmacro %} diff --git a/dev_requirements.txt b/dev_requirements.txt index 2a3c4c4a..239825e6 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -13,11 +13,12 @@ pip-tools pre-commit pytest pytest-dotenv -pytest-logbook pytest-csv pytest-xdist pytz tox>=3.13 twine wheel +pyodbc +azure-identity -e . 
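The utils macros added above provide SQL Server implementations of dbt's cross-database helpers. As a rough, illustrative sketch only (the dbo.orders table and its columns order_date, customer_code, and amount are hypothetical and not part of this change), a few of them would render to T-SQL along these lines:

    -- hypothetical rendered output of some of the new utils macros
    select
        dateadd(day, 7, cast(order_date as datetime2(6)))                as shifted_date,  -- sqlserver__dateadd
        cast(dateadd(month, datediff(month, 0, order_date), 0) as date)  as month_start,   -- sqlserver__date_trunc with datepart 'month'
        eomonth(order_date)                                              as month_end,     -- sqlserver__last_day with datepart 'month'
        charindex('ab', customer_code)                                   as match_pos,     -- sqlserver__position
        try_cast(amount as int)                                          as amount_int,    -- sqlserver__safe_cast
        lower(convert(varchar(50), hashbytes('md5', coalesce(convert(varchar(8000), customer_code), '')), 2)) as hashed_code  -- sqlserver__hash
    from dbo.orders;

Note that sqlserver__listagg maps to string_agg(...) within group (...) and, as written above, does not apply its limit_num argument, since T-SQL's string_agg has no per-group limit syntax.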
diff --git a/setup.py b/setup.py index 0a63ce6d..6f89b0b8 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,6 @@ def run(self): packages=find_namespace_packages(include=["dbt", "dbt.*"]), include_package_data=True, install_requires=[ - "dbt-fabric==1.9.3", "dbt-core>=1.9.0,<2.0", "dbt-common>=1.0,<2.0", "dbt-adapters>=1.11.0,<2.0", @@ -86,6 +85,7 @@ def run(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ], project_urls={ "Setup & configuration": "https://docs.getdbt.com/reference/warehouse-profiles/mssql-profile", # noqa: E501 diff --git a/tests/functional/adapter/dbt/test_changing_relation_type.py b/tests/functional/adapter/dbt/test_changing_relation_type.py new file mode 100644 index 00000000..aaa43fa0 --- /dev/null +++ b/tests/functional/adapter/dbt/test_changing_relation_type.py @@ -0,0 +1,5 @@ +from dbt.tests.adapter.relations.test_changing_relation_type import BaseChangeRelationTypeValidator + + +class TestChangeRelationTypesSQLServer(BaseChangeRelationTypeValidator): + pass diff --git a/tests/functional/adapter/dbt/test_data_types.py b/tests/functional/adapter/dbt/test_data_types.py new file mode 100644 index 00000000..6e7c679d --- /dev/null +++ b/tests/functional/adapter/dbt/test_data_types.py @@ -0,0 +1,60 @@ +import pytest +from dbt.tests.adapter.utils.data_types.test_type_bigint import BaseTypeBigInt +from dbt.tests.adapter.utils.data_types.test_type_boolean import BaseTypeBoolean +from dbt.tests.adapter.utils.data_types.test_type_float import BaseTypeFloat +from dbt.tests.adapter.utils.data_types.test_type_int import BaseTypeInt +from dbt.tests.adapter.utils.data_types.test_type_numeric import BaseTypeNumeric +from dbt.tests.adapter.utils.data_types.test_type_string import BaseTypeString +from dbt.tests.adapter.utils.data_types.test_type_timestamp import ( + BaseTypeTimestamp, + seeds__expected_csv, +) + + +@pytest.mark.skip(reason="SQL Server shows 'numeric' if you don't explicitly cast it to bigint") +class TestTypeBigIntSQLServer(BaseTypeBigInt): + pass + + +class TestTypeFloatSQLServer(BaseTypeFloat): + pass + + +class TestTypeIntSQLServer(BaseTypeInt): + pass + + +class TestTypeNumericSQLServer(BaseTypeNumeric): + pass + + +# Tests failing different than expected result +class TestTypeStringSQLServer(BaseTypeString): + def assert_columns_equal(self, project, expected_cols, actual_cols): + # ignore the size of the varchar since we do + # an optimization to not use varchar(max) all the time + assert ( + expected_cols[:-1] == actual_cols[:-1] + ), f"Type difference detected: {expected_cols} vs. 
{actual_cols}" + + +class TestTypeTimestampSQLServer(BaseTypeTimestamp): + @pytest.fixture(scope="class") + def seeds(self): + seeds__expected_yml = """ +version: 2 +seeds: + - name: expected + config: + column_types: + timestamp_col: "datetime2(6)" + """ + + return { + "expected.csv": seeds__expected_csv, + "expected.yml": seeds__expected_yml, + } + + +class TestTypeBooleanSQLServer(BaseTypeBoolean): + pass diff --git a/tests/functional/adapter/dbt/test_dbt_show.py b/tests/functional/adapter/dbt/test_dbt_show.py new file mode 100644 index 00000000..dca94b46 --- /dev/null +++ b/tests/functional/adapter/dbt/test_dbt_show.py @@ -0,0 +1,52 @@ +import pytest +from dbt.tests.adapter.dbt_show.fixtures import ( + models__ephemeral_model, + models__sample_model, + models__second_ephemeral_model, + seeds__sample_seed, +) +from dbt.tests.adapter.dbt_show.test_dbt_show import BaseShowSqlHeader +from dbt.tests.util import run_dbt + + +# -- Base classes for tests are imported based on whether the adapter supports dbt show -- +class BaseShowLimit: + @pytest.fixture(scope="class") + def models(self): + return { + "sample_model.sql": models__sample_model, + "ephemeral_model.sql": models__ephemeral_model, + } + + @pytest.fixture(scope="class") + def seeds(self): + return {"sample_seed.csv": seeds__sample_seed} + + @pytest.mark.parametrize( + "args,expected", + [ + ([], 5), # default limit + (["--limit", 3], 3), # fetch 3 rows + (["--limit", -1], 7), # fetch all rows + ], + ) + def test_limit(self, project, args, expected): + run_dbt(["build"]) + dbt_args = ["show", "--inline", models__second_ephemeral_model, *args] + results = run_dbt(dbt_args) + assert len(results.results[0].agate_table) == expected + # ensure limit was injected in compiled_code when limit specified in command args + limit = results.args.get("limit") + if limit > 0: + assert ( + f"offset 0 rows fetch first { limit } rows only" + in results.results[0].node.compiled_code + ) + + +class TestSQLServerShowLimit(BaseShowLimit): + pass + + +class TestSQLServerShowSqlHeader(BaseShowSqlHeader): + pass diff --git a/tests/functional/adapter/dbt/test_equals.py b/tests/functional/adapter/dbt/test_equals.py new file mode 100644 index 00000000..96067a86 --- /dev/null +++ b/tests/functional/adapter/dbt/test_equals.py @@ -0,0 +1,5 @@ +from dbt.tests.adapter.utils.test_equals import BaseEquals + + +class TestEqualsSQLServer(BaseEquals): + pass diff --git a/tests/functional/adapter/dbt/test_get_last_relation_modified.py b/tests/functional/adapter/dbt/test_get_last_relation_modified.py new file mode 100644 index 00000000..f141a700 --- /dev/null +++ b/tests/functional/adapter/dbt/test_get_last_relation_modified.py @@ -0,0 +1,59 @@ +import os + +import pytest +from dbt.cli.main import dbtRunner + +freshness_via_metadata_schema_yml = """version: 2 +sources: + - name: test_source + freshness: + warn_after: {count: 10, period: hour} + error_after: {count: 1, period: day} + schema: "{{ env_var('DBT_GET_LAST_RELATION_TEST_SCHEMA') }}" + tables: + - name: test_table +""" + + +class TestGetLastRelationModified: + @pytest.fixture(scope="class", autouse=True) + def set_env_vars(self, project): + os.environ["DBT_GET_LAST_RELATION_TEST_SCHEMA"] = project.test_schema + yield + del os.environ["DBT_GET_LAST_RELATION_TEST_SCHEMA"] + + @pytest.fixture(scope="class") + def models(self): + return {"schema.yml": freshness_via_metadata_schema_yml} + + @pytest.fixture(scope="class") + def custom_schema(self, project, set_env_vars): + with 
project.adapter.connection_named("__test"): + relation = project.adapter.Relation.create( + database=project.database, schema=os.environ["DBT_GET_LAST_RELATION_TEST_SCHEMA"] + ) + project.adapter.drop_schema(relation) + project.adapter.create_schema(relation) + + yield relation.schema + + with project.adapter.connection_named("__test"): + project.adapter.drop_schema(relation) + + def test_get_last_relation_modified(self, project, set_env_vars, custom_schema): + project.run_sql( + f"create table {custom_schema}.test_table (id int, name varchar(100) not null);" + ) + + warning_or_error = False + + def probe(e): + nonlocal warning_or_error + if e.info.level in ["warning", "error"]: + warning_or_error = True + + runner = dbtRunner(callbacks=[probe]) + runner.invoke(["source", "freshness"]) + + # The 'source freshness' command should succeed without warnings or errors. + assert not warning_or_error diff --git a/tests/functional/adapter/dbt/test_list_relations_without_caching.py b/tests/functional/adapter/dbt/test_list_relations_without_caching.py new file mode 100644 index 00000000..1e257bbc --- /dev/null +++ b/tests/functional/adapter/dbt/test_list_relations_without_caching.py @@ -0,0 +1,159 @@ +import json + +import pytest +from dbt.tests.util import run_dbt, run_dbt_and_capture + +NUM_VIEWS = 10 +NUM_EXPECTED_RELATIONS = 1 + NUM_VIEWS + +TABLE_BASE_SQL = """ +{{ config(materialized='table') }} + +select 1 as id +""".lstrip() + +VIEW_X_SQL = """ +select id from {{ ref('my_model_base') }} +""".lstrip() + +VALIDATE_LIST_RELATIONS_MACRO = """ +{% macro validate_list_relations_without_caching(schema_relation) -%} + + {% call statement('list_relations_without_caching', fetch_result=True) -%} + with base as ( + select + DB_NAME() as [database], + t.name as [name], + SCHEMA_NAME(t.schema_id) as [schema], + 'table' as table_type + from sys.tables as t {{ information_schema_hints() }} + union all + select + DB_NAME() as [database], + v.name as [name], + SCHEMA_NAME(v.schema_id) as [schema], + 'view' as table_type + from sys.views as v {{ information_schema_hints() }} + ) + select * from base + where [schema] like '{{ schema_relation }}' + {% endcall %} + + {% set relation_list_result = load_result('list_relations_without_caching').table %} + {% set n_relations = relation_list_result | length %} + {{ log("n_relations: " ~ n_relations) }} +{% endmacro %} +""" + + +def parse_json_logs(json_log_output): + parsed_logs = [] + for line in json_log_output.split("\n"): + try: + log = json.loads(line) + except ValueError: + continue + + parsed_logs.append(log) + + return parsed_logs + + +def find_result_in_parsed_logs(parsed_logs, result_name): + return next( + ( + item["data"]["msg"] + for item in parsed_logs + if result_name in item["data"].get("msg", "msg") + ), + False, + ) + + +def find_exc_info_in_parsed_logs(parsed_logs, exc_info_name): + return next( + ( + item["data"]["exc_info"] + for item in parsed_logs + if exc_info_name in item["data"].get("exc_info", "exc_info") + ), + False, + ) + + +class TestListRelationsWithoutCachingSingleSQLServer: + @pytest.fixture(scope="class") + def models(self): + my_models = {"my_model_base.sql": TABLE_BASE_SQL} + for view in range(0, NUM_VIEWS): + my_models.update({f"my_model_{view}.sql": VIEW_X_SQL}) + + return my_models + + @pytest.fixture(scope="class") + def macros(self): + return { + "validate_list_relations_without_caching.sql": VALIDATE_LIST_RELATIONS_MACRO, + } + + def test__sqlserver__list_relations_without_caching(self, project): + """ + validates that 
sqlserver__list_relations_without_caching + macro returns a single record + """ + run_dbt(["run", "-s", "my_model_base"]) + + # database = project.database + schemas = project.created_schemas + + for schema in schemas: + kwargs = {"schema_relation": schema} + _, log_output = run_dbt_and_capture( + [ + "--debug", + # "--log-format=json", + "run-operation", + "validate_list_relations_without_caching", + "--args", + str(kwargs), + ] + ) + assert "n_relations: 1" in log_output + + +class TestListRelationsWithoutCachingFullSQLServer: + @pytest.fixture(scope="class") + def models(self): + my_models = {"my_model_base.sql": TABLE_BASE_SQL} + for view in range(0, NUM_VIEWS): + my_models.update({f"my_model_{view}.sql": VIEW_X_SQL}) + + return my_models + + @pytest.fixture(scope="class") + def macros(self): + return { + "validate_list_relations_without_caching.sql": VALIDATE_LIST_RELATIONS_MACRO, + } + + def test_sqlserver__list_relations_without_caching(self, project): + # purpose of the first run is to create the replicated views in the target schema + run_dbt(["run"]) + + # database = project.database + schemas = project.created_schemas + + for schema in schemas: + # schema_relation = f"{database}.{schema}" + kwargs = {"schema_relation": schema} + _, log_output = run_dbt_and_capture( + [ + "--debug", + # "--log-format=json", + "run-operation", + "validate_list_relations_without_caching", + "--args", + str(kwargs), + ] + ) + assert f"n_relations: {NUM_EXPECTED_RELATIONS}" in log_output diff --git a/tests/functional/adapter/dbt/test_null_compare.py b/tests/functional/adapter/dbt/test_null_compare.py new file mode 100644 index 00000000..e25acd94 --- /dev/null +++ b/tests/functional/adapter/dbt/test_null_compare.py @@ -0,0 +1,9 @@ +from dbt.tests.adapter.utils.test_null_compare import BaseMixedNullCompare, BaseNullCompare + + +class TestMixedNullCompareSQLServer(BaseMixedNullCompare): + pass + + +class TestNullCompareSQLServer(BaseNullCompare): + pass diff --git a/tests/functional/adapter/dbt/test_relation_types.py b/tests/functional/adapter/dbt/test_relation_types.py new file mode 100644 index 00000000..af707f00 --- /dev/null +++ b/tests/functional/adapter/dbt/test_relation_types.py @@ -0,0 +1,60 @@ +import pytest +from dbt.artifacts.schemas.catalog import CatalogArtifact +from dbt.tests.util import run_dbt + +MY_SEED = """ +id,value +1,100 +2,200 +3,300 +""".strip() + + +MY_TABLE = """ +{{ config( + materialized='table', +) }} +select * from {{ ref('my_seed') }} +""" + + +MY_VIEW = """ +{{ config( + materialized='view', +) }} +select * from {{ ref('my_seed') }} +""" + + +class TestCatalogRelationTypes: + @pytest.fixture(scope="class", autouse=True) + def seeds(self): + return {"my_seed.csv": MY_SEED} + + @pytest.fixture(scope="class", autouse=True) + def models(self): + yield { + "my_table.sql": MY_TABLE, + "my_view.sql": MY_VIEW, + } + + @pytest.fixture(scope="class", autouse=True) + def docs(self, project): + run_dbt(["seed"]) + run_dbt(["run"]) + yield run_dbt(["docs", "generate"]) + + @pytest.mark.parametrize( + "node_name,relation_type", + [ + ("seed.test.my_seed", "BASE TABLE"), + ("model.test.my_table", "BASE TABLE"), + ("model.test.my_view", "VIEW"), + ], + ) + def test_relation_types_populate_correctly( + self, docs: CatalogArtifact, node_name: str, relation_type: str + ): + assert node_name in docs.nodes + node = docs.nodes[node_name] + assert node.metadata.type == relation_type diff --git a/tests/functional/adapter/dbt/test_schema.py b/tests/functional/adapter/dbt/test_schema.py new file 
mode 100644 index 00000000..6aa44c76 --- /dev/null +++ b/tests/functional/adapter/dbt/test_schema.py @@ -0,0 +1,36 @@ +import os + +import pytest +from dbt.tests.util import run_dbt + + +class TestSchemaCreation: + @pytest.fixture(scope="class") + def models(self): + return { + "dummy.sql": """ +{{ config(schema='with_custom_auth') }} +select 1 as id +""", + } + + @staticmethod + @pytest.fixture(scope="class") + def dbt_profile_target_update(): + return {"schema_authorization": "{{ env_var('DBT_TEST_USER_1') }}"} + + @staticmethod + def _verify_schema_owner(schema_name, owner, project): + get_schema_owner = f""" +select SCHEMA_OWNER from INFORMATION_SCHEMA.SCHEMATA where SCHEMA_NAME = '{schema_name}' + """ + result = project.run_sql(get_schema_owner, fetch="one")[0] + assert result == owner + + def test_schema_creation(self, project, unique_schema): + res = run_dbt(["run"]) + assert len(res) == 1 + + self._verify_schema_owner( + f"{unique_schema}_with_custom_auth", os.getenv("DBT_TEST_USER_1"), project + ) diff --git a/tests/functional/adapter/dbt/test_snapshot_configs.py b/tests/functional/adapter/dbt/test_snapshot_configs.py new file mode 100644 index 00000000..0585d64f --- /dev/null +++ b/tests/functional/adapter/dbt/test_snapshot_configs.py @@ -0,0 +1,780 @@ +# flake8: noqa: E501 +import datetime + +import pytest +from dbt.tests.util import ( + check_relations_equal, + get_manifest, + run_dbt, + run_dbt_and_capture, + run_sql_with_adapter, + update_config_file, +) + +model_seed_sql = """ +select * from "{{target.database}}".{{target.schema}}.seed +""" + +snapshots_multi_key_yml = """ +snapshots: + - name: snapshot_actual + relation: "ref('seed')" + config: + strategy: timestamp + updated_at: updated_at + unique_key: + - id1 + - id2 + snapshot_meta_column_names: + dbt_valid_to: test_valid_to + dbt_valid_from: test_valid_from + dbt_scd_id: test_scd_id + dbt_updated_at: test_updated_at +""" + +# multi-key snapshot fixtures + +create_multi_key_seed_sql = """ +create table {schema}.seed ( + id1 INTEGER, + id2 INTEGER, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(50), + gender VARCHAR(50), + ip_address VARCHAR(20), + updated_at DATETIME2(6) +); +""" + +create_multi_key_snapshot_expected_sql = """ +create table {schema}.snapshot_expected ( + id1 INTEGER, + id2 INTEGER, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(50), + gender VARCHAR(50), + ip_address VARCHAR(20), + + -- snapshotting fields + updated_at DATETIME2(6), + test_valid_from DATETIME2(6), + test_valid_to DATETIME2(6), + test_scd_id VARCHAR(50), + test_updated_at DATETIME2(6) +); +""" + +seed_multi_key_insert_sql = """ +-- seed inserts +-- use the same email for two users to verify that duplicated check_cols values +-- are handled appropriately +insert into {schema}.seed (id1, id2, first_name, last_name, email, gender, ip_address, updated_at) values +(1, 100, 'Judith', 'Kennedy', '(not provided)', 'Female', '54.60.24.128', '2015-12-24 12:19:28'), +(2, 200, 'Arthur', 'Kelly', '(not provided)', 'Male', '62.56.24.215', '2015-10-28 16:22:15'), +(3, 300, 'Rachel', 'Moreno', 'rmoreno2@msu.edu', 'Female', '31.222.249.23', '2016-04-05 02:05:30'), +(4, 400, 'Ralph', 'Turner', 'rturner3@hp.com', 'Male', '157.83.76.114', '2016-08-08 00:06:51'), +(5, 500, 'Laura', 'Gonzales', 'lgonzales4@howstuffworks.com', 'Female', '30.54.105.168', '2016-09-01 08:25:38'), +(6, 600, 'Katherine', 'Lopez', 'klopez5@yahoo.co.jp', 'Female', '169.138.46.89', '2016-08-30 18:52:11'), +(7, 700, 'Jeremy', 'Hamilton', 
'jhamilton6@mozilla.org', 'Male', '231.189.13.133', '2016-07-17 02:09:46'), +(8, 800, 'Heather', 'Rose', 'hrose7@goodreads.com', 'Female', '87.165.201.65', '2015-12-29 22:03:56'), +(9, 900, 'Gregory', 'Kelly', 'gkelly8@trellian.com', 'Male', '154.209.99.7', '2016-03-24 21:18:16'), +(10, 1000, 'Rachel', 'Lopez', 'rlopez9@themeforest.net', 'Female', '237.165.82.71', '2016-08-20 15:44:49'), +(11, 1100, 'Donna', 'Welch', 'dwelcha@shutterfly.com', 'Female', '103.33.110.138', '2016-02-27 01:41:48'), +(12, 1200, 'Russell', 'Lawrence', 'rlawrenceb@qq.com', 'Male', '189.115.73.4', '2016-06-11 03:07:09'), +(13, 1300, 'Michelle', 'Montgomery', 'mmontgomeryc@scientificamerican.com', 'Female', '243.220.95.82', '2016-06-18 16:27:19'), +(14, 1400, 'Walter', 'Castillo', 'wcastillod@pagesperso-orange.fr', 'Male', '71.159.238.196', '2016-10-06 01:55:44'), +(15, 1500, 'Robin', 'Mills', 'rmillse@vkontakte.ru', 'Female', '172.190.5.50', '2016-10-31 11:41:21'), +(16, 1600, 'Raymond', 'Holmes', 'rholmesf@usgs.gov', 'Male', '148.153.166.95', '2016-10-03 08:16:38'), +(17, 1700, 'Gary', 'Bishop', 'gbishopg@plala.or.jp', 'Male', '161.108.182.13', '2016-08-29 19:35:20'), +(18, 1800, 'Anna', 'Riley', 'arileyh@nasa.gov', 'Female', '253.31.108.22', '2015-12-11 04:34:27'), +(19, 1900, 'Sarah', 'Knight', 'sknighti@foxnews.com', 'Female', '222.220.3.177', '2016-09-26 00:49:06'), +(20, 2000, 'Phyllis', 'Fox', null, 'Female', '163.191.232.95', '2016-08-21 10:35:19'); +""" + +populate_multi_key_snapshot_expected_sql = """ +-- populate snapshot table +insert into {schema}.snapshot_expected ( + id1, + id2, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + test_valid_from, + test_valid_to, + test_updated_at, + test_scd_id +) + +select + id1, + id2, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as test_valid_from, + cast(null as Datetime2(6)) as test_valid_to, + updated_at as test_updated_at, + convert( + varchar(50), + hashbytes( + 'md5', + coalesce(cast(id1 as varchar(8000)), '') + + '|' + + coalesce(cast(id2 as varchar(8000)), '') + + '|' + + coalesce(cast(updated_at as varchar(8000)), '') + ), + 2 + ) as test_scd_id +from {schema}.seed; +""" + +model_seed_sql = """ +select * from "{{target.database}}"."{{target.schema}}".seed +""" + +snapshots_multi_key_yml = """ +snapshots: + - name: snapshot_actual + relation: "ref('seed')" + config: + strategy: timestamp + updated_at: updated_at + unique_key: + - id1 + - id2 + snapshot_meta_column_names: + dbt_valid_to: test_valid_to + dbt_valid_from: test_valid_from + dbt_scd_id: test_scd_id + dbt_updated_at: test_updated_at +""" + +invalidate_multi_key_sql = """ +-- update records 11 - 21. 
Change email and updated_at field +update {schema}.seed set + updated_at = CAST(DATEADD(HOUR, 1, updated_at) AS datetime2(6)), + email = case when id1 = 20 then 'pfoxj@creativecommons.org' else 'new_' + email end +where id1 >= 10 and id1 <= 20; + + +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = CAST(DATEADD(HOUR, 1, updated_at) AS datetime2(6)) +where id1 >= 10 and id1 <= 20; + +""" + +update_multi_key_sql = """ +-- insert v2 of the 11 - 21 records + +insert into {schema}.snapshot_expected ( + id1, + id2, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + test_valid_from, + test_valid_to, + test_updated_at, + test_scd_id +) + +select + id1, + id2, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as test_valid_from, + cast(null as Datetime2(6)) as test_valid_to, + updated_at as test_updated_at, + convert( + varchar(50), + hashbytes( + 'md5', + coalesce(cast(id1 as varchar(8000)), '') + + '|' + + coalesce(cast(id2 as varchar(8000)), '') + + '|' + + coalesce(cast(updated_at as varchar(8000)), '') + ), + 2 + ) as test_scd_id +from {schema}.seed +where id1 >= 10 and id1 <= 20; +""" + +snapshot_actual_sql = """ +{% snapshot snapshot_actual %} + + {{ + config( + unique_key='cast(id as varchar(8000)) + '~ "'-'" ~ ' + cast(first_name as varchar(8000))', + ) + }} + + select * from "{{target.database}}"."{{target.schema}}".seed + +{% endsnapshot %} +""" + +snapshots_valid_to_current_yml = """ +snapshots: + - name: snapshot_actual + config: + strategy: timestamp + updated_at: updated_at + dbt_valid_to_current: "cast('2099-12-31' as date)" + snapshot_meta_column_names: + dbt_valid_to: test_valid_to + dbt_valid_from: test_valid_from + dbt_scd_id: test_scd_id + dbt_updated_at: test_updated_at +""" + +ref_snapshot_sql = """ +select * from {{ ref('snapshot_actual') }} +""" + +create_seed_sql = """ +create table {schema}.seed ( + id INT, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(50), + gender VARCHAR(50), + ip_address VARCHAR(20), + updated_at DATETIME2(6) +); +""" + +create_snapshot_expected_sql = """ +create table {schema}.snapshot_expected ( + id INT, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(50), + gender VARCHAR(50), + ip_address VARCHAR(20), + + -- snapshotting fields + updated_at DATETIME2(6), + test_valid_from DATETIME2(6), + test_valid_to DATETIME2(6), + test_scd_id VARCHAR(50), + test_updated_at DATETIME2(6) +); +""" + +seed_insert_sql = """ +-- seed inserts +-- use the same email for two users to verify that duplicated check_cols values +-- are handled appropriately +insert into {schema}.seed (id, first_name, last_name, email, gender, ip_address, updated_at) values +(1, 'Judith', 'Kennedy', '(not provided)', 'Female', '54.60.24.128', '2015-12-24 12:19:28'), +(2, 'Arthur', 'Kelly', '(not provided)', 'Male', '62.56.24.215', '2015-10-28 16:22:15'), +(3, 'Rachel', 'Moreno', 'rmoreno2@msu.edu', 'Female', '31.222.249.23', '2016-04-05 02:05:30'), +(4, 'Ralph', 'Turner', 'rturner3@hp.com', 'Male', '157.83.76.114', '2016-08-08 00:06:51'), +(5, 'Laura', 'Gonzales', 'lgonzales4@howstuffworks.com', 'Female', '30.54.105.168', '2016-09-01 08:25:38'), +(6, 'Katherine', 'Lopez', 'klopez5@yahoo.co.jp', 'Female', '169.138.46.89', '2016-08-30 18:52:11'), +(7, 'Jeremy', 'Hamilton', 'jhamilton6@mozilla.org', 'Male', '231.189.13.133', '2016-07-17 02:09:46'), +(8, 'Heather', 'Rose', 'hrose7@goodreads.com', 'Female', '87.165.201.65', 
'2015-12-29 22:03:56'), +(9, 'Gregory', 'Kelly', 'gkelly8@trellian.com', 'Male', '154.209.99.7', '2016-03-24 21:18:16'), +(10, 'Rachel', 'Lopez', 'rlopez9@themeforest.net', 'Female', '237.165.82.71', '2016-08-20 15:44:49'), +(11, 'Donna', 'Welch', 'dwelcha@shutterfly.com', 'Female', '103.33.110.138', '2016-02-27 01:41:48'), +(12, 'Russell', 'Lawrence', 'rlawrenceb@qq.com', 'Male', '189.115.73.4', '2016-06-11 03:07:09'), +(13, 'Michelle', 'Montgomery', 'mmontgomeryc@scientificamerican.com', 'Female', '243.220.95.82', '2016-06-18 16:27:19'), +(14, 'Walter', 'Castillo', 'wcastillod@pagesperso-orange.fr', 'Male', '71.159.238.196', '2016-10-06 01:55:44'), +(15, 'Robin', 'Mills', 'rmillse@vkontakte.ru', 'Female', '172.190.5.50', '2016-10-31 11:41:21'), +(16, 'Raymond', 'Holmes', 'rholmesf@usgs.gov', 'Male', '148.153.166.95', '2016-10-03 08:16:38'), +(17, 'Gary', 'Bishop', 'gbishopg@plala.or.jp', 'Male', '161.108.182.13', '2016-08-29 19:35:20'), +(18, 'Anna', 'Riley', 'arileyh@nasa.gov', 'Female', '253.31.108.22', '2015-12-11 04:34:27'), +(19, 'Sarah', 'Knight', 'sknighti@foxnews.com', 'Female', '222.220.3.177', '2016-09-26 00:49:06'), +(20, 'Phyllis', 'Fox', null, 'Female', '163.191.232.95', '2016-08-21 10:35:19'); +""" + +populate_snapshot_expected_valid_to_current_sql = """ +-- populate snapshot table +insert into {schema}.snapshot_expected ( + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + test_valid_from, + test_valid_to, + test_updated_at, + test_scd_id +) + +select + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as test_valid_from, + cast('2099-12-31' as date) as test_valid_to, + updated_at as test_updated_at, + convert( + varchar(50), + hashbytes( + 'md5', + coalesce(cast(id as varchar(8000)), '') + + '-' + + coalesce(cast(first_name as varchar(8000)), '') + + '|' + + coalesce(cast(updated_at as varchar(8000)), '') + ), + 2 + ) as test_scd_id +from {schema}.seed; +""" + +populate_snapshot_expected_sql = """ +-- populate snapshot table +insert into {schema}.snapshot_expected ( + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + test_valid_from, + test_valid_to, + test_updated_at, + test_scd_id +) + +select + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as test_valid_from, + cast(null as date) as test_valid_to, + updated_at as test_updated_at, + convert( + varchar(50), + hashbytes( + 'md5', + coalesce(cast(id as varchar(8000)), '') + + '-' + + coalesce(cast(first_name as varchar(8000)), '') + + '|' + + coalesce(cast(updated_at as varchar(8000)), '') + ), + 2 + ) as test_scd_id +from {schema}.seed; +""" + +update_with_current_sql = """ +-- insert v2 of the 11 - 21 records + +insert into {schema}.snapshot_expected ( + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + test_valid_from, + test_valid_to, + test_updated_at, + test_scd_id +) + +select + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as test_valid_from, + cast('2099-12-31' as date) as test_valid_to, + updated_at as test_updated_at, + convert( + varchar(50), + hashbytes( + 'md5', + coalesce(cast(id as varchar(8000)), '') + + '-' + + coalesce(cast(first_name as varchar(8000)), '') + + '|' + + coalesce(cast(updated_at as varchar(8000)), '') + ), + 2 + ) as test_scd_id +from {schema}.seed +where id >= 10 and id 
<= 20; +""" + +invalidate_sql = """ +-- update records 11 - 21. Change email and updated_at field +update {schema}.seed set + updated_at = CAST(DATEADD(HOUR, 1, updated_at) AS datetime2(6)), + email = case when id = 20 then 'pfoxj@creativecommons.org' else 'new_' + email end +where id >= 10 and id <= 20; + +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set + test_valid_to = CAST(DATEADD(HOUR, 1, updated_at) AS datetime2(6)) +where id >= 10 and id <= 20; +""" + +snapshots_no_column_names_yml = """ +snapshots: + - name: snapshot_actual + config: + strategy: timestamp + updated_at: updated_at +""" + +ref_snapshot_sql = """ +select * from {{ ref('snapshot_actual') }} +""" + +update_sql = """ +-- insert v2 of the 11 - 21 records + +insert into {schema}.snapshot_expected ( + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + test_valid_from, + test_valid_to, + test_updated_at, + test_scd_id +) + +select + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as test_valid_from, + cast (null as date) as test_valid_to, + updated_at as test_updated_at, + convert( + varchar(50), + hashbytes( + 'md5', + coalesce(cast(id as varchar(8000)), '') + + '-' + + coalesce(cast(first_name as varchar(8000)), '') + + '|' + + coalesce(cast(updated_at as varchar(8000)), '') + ), + 2 + ) as test_scd_id +from {schema}.seed +where id >= 10 and id <= 20; +""" + +snapshots_yml = """ +snapshots: + - name: snapshot_actual + config: + strategy: timestamp + updated_at: updated_at + snapshot_meta_column_names: + dbt_valid_to: test_valid_to + dbt_valid_from: test_valid_from + dbt_scd_id: test_scd_id + dbt_updated_at: test_updated_at +""" + + +class BaseSnapshotDbtValidToCurrent: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_valid_to_current_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_valid_to_current(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_valid_to_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + original_snapshot = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + assert original_snapshot[0][2] == datetime.datetime(2099, 12, 31, 0, 0) + assert original_snapshot[9][2] == datetime.datetime(2099, 12, 31, 0, 0) + + project.run_sql(invalidate_sql) + project.run_sql(update_with_current_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + updated_snapshot = run_sql_with_adapter( + project.adapter, + "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", + "all", + ) + assert len(updated_snapshot) == 31 + + updated_snapshot_row_count = run_sql_with_adapter( + project.adapter, + "select count(*) from {schema}.snapshot_actual where test_valid_to != '2099-12-31 00:00:00.000000'", + "all", + ) + assert updated_snapshot_row_count[0][0] == 11 + + updated_snapshot_row_17 = run_sql_with_adapter( + project.adapter, + "select id from {schema}.snapshot_actual where test_valid_to = '2016-08-29 20:35:20.000000'", + "all", + ) + assert updated_snapshot_row_17[0][0] == 17 + + updated_snapshot_row_16 = run_sql_with_adapter( + project.adapter, + "select id from {schema}.snapshot_actual 
where test_valid_to = '2016-10-03 09:16:38.000000'", + "all", + ) + assert updated_snapshot_row_16[0][0] == 16 + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestSnapshotDbtValidToCurrent(BaseSnapshotDbtValidToCurrent): + pass + + +class BaseSnapshotColumnNamesFromDbtProject: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_column_names_from_project(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestBaseSnapshotColumnNamesFromDbtProject(BaseSnapshotColumnNamesFromDbtProject): + pass + + +class BaseSnapshotColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_snapshot_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestBaseSnapshotColumnNames(BaseSnapshotColumnNames): + pass + + +class BaseSnapshotInvalidColumnNames: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": snapshots_no_column_names_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at": "test_updated_at", + } + } + } + } + + def test_snapshot_invalid_column_names(self, project): + project.run_sql(create_seed_sql) + project.run_sql(create_snapshot_expected_sql) + project.run_sql(seed_insert_sql) + project.run_sql(populate_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + manifest = get_manifest(project.project_root) + snapshot_node = manifest.nodes["snapshot.test.snapshot_actual"] + assert snapshot_node.config.snapshot_meta_column_names == { + "dbt_valid_to": "test_valid_to", + "dbt_valid_from": "test_valid_from", + "dbt_scd_id": "test_scd_id", + "dbt_updated_at":
"test_updated_at", + } + + project.run_sql(invalidate_sql) + project.run_sql(update_sql) + + # Change snapshot_meta_columns and look for an error + different_columns = { + "snapshots": { + "test": { + "+snapshot_meta_column_names": { + "dbt_valid_to": "test_valid_to", + "dbt_updated_at": "test_updated_at", + } + } + } + } + update_config_file(different_columns, "dbt_project.yml") + + results, log_output = run_dbt_and_capture(["snapshot"], expect_pass=False) + assert len(results) == 1 + assert "dbt_scd_id" in log_output + assert "1 of 1 ERROR snapshotting test" in log_output + + +class TestBaseSnapshotInvalidColumnNames(BaseSnapshotInvalidColumnNames): + pass + + +# This uses snapshot_meta_column_names, yaml-only snapshot def, +# and multiple keys +class BaseSnapshotMultiUniqueKey: + @pytest.fixture(scope="class") + def models(self): + return { + "seed.sql": model_seed_sql, + "snapshots.yml": snapshots_multi_key_yml, + "ref_snapshot.sql": ref_snapshot_sql, + } + + def test_multi_column_unique_key(self, project): + project.run_sql(create_multi_key_seed_sql) + project.run_sql(create_multi_key_snapshot_expected_sql) + project.run_sql(seed_multi_key_insert_sql) + project.run_sql(populate_multi_key_snapshot_expected_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_multi_key_sql) + project.run_sql(update_multi_key_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + +class TestBaseSnapshotMultiUniqueKey(BaseSnapshotMultiUniqueKey): + pass diff --git a/tests/functional/adapter/dbt/test_snapshot_new_record_mode.py b/tests/functional/adapter/dbt/test_snapshot_new_record_mode.py new file mode 100644 index 00000000..4b60ef50 --- /dev/null +++ b/tests/functional/adapter/dbt/test_snapshot_new_record_mode.py @@ -0,0 +1,226 @@ +# flake8: noqa: E501 +import pytest +from dbt.tests.util import check_relations_equal, run_dbt + +_seed_new_record_mode = """ +create table {database}.{schema}.seed ( + id INTEGER, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(50), + gender VARCHAR(50), + ip_address VARCHAR(20), + updated_at DATETIME2(6) +); + +create table {database}.{schema}.snapshot_expected ( + id INTEGER, + first_name VARCHAR(50), + last_name VARCHAR(50), + email VARCHAR(50), + gender VARCHAR(50), + ip_address VARCHAR(20), + + -- snapshotting fields + updated_at DATETIME2(6), + dbt_valid_from DATETIME2(6), + dbt_valid_to DATETIME2(6), + dbt_scd_id VARCHAR(50), + dbt_updated_at DATETIME2(6), + dbt_is_deleted VARCHAR(50) +); + + +-- seed inserts +-- use the same email for two users to verify that duplicated check_cols values +-- are handled appropriately +insert into {database}.{schema}.seed (id, first_name, last_name, email, gender, ip_address, updated_at) values +(1, 'Judith', 'Kennedy', '(not provided)', 'Female', '54.60.24.128', '2015-12-24 12:19:28'), +(2, 'Arthur', 'Kelly', '(not provided)', 'Male', '62.56.24.215', '2015-10-28 16:22:15'), +(3, 'Rachel', 'Moreno', 'rmoreno2@msu.edu', 'Female', '31.222.249.23', '2016-04-05 02:05:30'), +(4, 'Ralph', 'Turner', 'rturner3@hp.com', 'Male', '157.83.76.114', '2016-08-08 00:06:51'), +(5, 'Laura', 'Gonzales', 'lgonzales4@howstuffworks.com', 'Female', '30.54.105.168', '2016-09-01 08:25:38'), +(6, 'Katherine', 'Lopez', 'klopez5@yahoo.co.jp', 'Female', '169.138.46.89', '2016-08-30 18:52:11'), +(7, 'Jeremy', 'Hamilton', 'jhamilton6@mozilla.org', 'Male', '231.189.13.133', '2016-07-17 02:09:46'), +(8, 
'Heather', 'Rose', 'hrose7@goodreads.com', 'Female', '87.165.201.65', '2015-12-29 22:03:56'), +(9, 'Gregory', 'Kelly', 'gkelly8@trellian.com', 'Male', '154.209.99.7', '2016-03-24 21:18:16'), +(10, 'Rachel', 'Lopez', 'rlopez9@themeforest.net', 'Female', '237.165.82.71', '2016-08-20 15:44:49'), +(11, 'Donna', 'Welch', 'dwelcha@shutterfly.com', 'Female', '103.33.110.138', '2016-02-27 01:41:48'), +(12, 'Russell', 'Lawrence', 'rlawrenceb@qq.com', 'Male', '189.115.73.4', '2016-06-11 03:07:09'), +(13, 'Michelle', 'Montgomery', 'mmontgomeryc@scientificamerican.com', 'Female', '243.220.95.82', '2016-06-18 16:27:19'), +(14, 'Walter', 'Castillo', 'wcastillod@pagesperso-orange.fr', 'Male', '71.159.238.196', '2016-10-06 01:55:44'), +(15, 'Robin', 'Mills', 'rmillse@vkontakte.ru', 'Female', '172.190.5.50', '2016-10-31 11:41:21'), +(16, 'Raymond', 'Holmes', 'rholmesf@usgs.gov', 'Male', '148.153.166.95', '2016-10-03 08:16:38'), +(17, 'Gary', 'Bishop', 'gbishopg@plala.or.jp', 'Male', '161.108.182.13', '2016-08-29 19:35:20'), +(18, 'Anna', 'Riley', 'arileyh@nasa.gov', 'Female', '253.31.108.22', '2015-12-11 04:34:27'), +(19, 'Sarah', 'Knight', 'sknighti@foxnews.com', 'Female', '222.220.3.177', '2016-09-26 00:49:06'), +(20, 'Phyllis', 'Fox', null, 'Female', '163.191.232.95', '2016-08-21 10:35:19'); + + +-- populate snapshot table +insert into {database}.{schema}.snapshot_expected ( + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + dbt_valid_from, + dbt_valid_to, + dbt_updated_at, + dbt_scd_id, + dbt_is_deleted +) + +select + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as dbt_valid_from, + cast(null as date) as dbt_valid_to, + updated_at as dbt_updated_at, + convert(varchar(50), hashbytes('md5', coalesce(cast(id as varchar(8000)), '') + '-' + coalesce(cast(first_name as varchar(8000)), '') + '|' + coalesce(cast(updated_at as varchar(8000)), '')), 2) as dbt_scd_id, + 'False' as dbt_is_deleted +from {database}.{schema}.seed; +""" + +_snapshot_actual_sql = """ +{% snapshot snapshot_actual %} + + {{ + config( + unique_key='cast(id as varchar(8000)) + '~ "'-'" ~ ' + cast(first_name as varchar(8000))', + ) + }} + select * from "{{target.database}}"."{{target.schema}}".seed + +{% endsnapshot %} +""" + +_snapshots_yml = """ +snapshots: + - name: snapshot_actual + config: + strategy: timestamp + updated_at: updated_at + hard_deletes: new_record +""" + +_ref_snapshot_sql = """ +select * from {{ ref('snapshot_actual') }} +""" + + +_invalidate_sql = """ +-- update records 11 - 21. 
Change email and updated_at field +update {schema}.seed set +updated_at = CAST(DATEADD(HOUR, 1, updated_at) AS datetime2(6)), +email = case when id = 20 then 'pfoxj@creativecommons.org' else 'new_' + email end +where id >= 10 and id <= 20; + + +-- invalidate records 11 - 21 +update {schema}.snapshot_expected set +dbt_valid_to = CAST(DATEADD(HOUR, 1, updated_at) AS datetime2(6)) +where id >= 10 and id <= 20; + +""" + +_update_sql = """ +-- insert v2 of the 11 - 21 records + +insert into {database}.{schema}.snapshot_expected ( + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + dbt_valid_from, + dbt_valid_to, + dbt_updated_at, + dbt_scd_id, + dbt_is_deleted +) + +select + id, + first_name, + last_name, + email, + gender, + ip_address, + updated_at, + -- fields added by snapshotting + updated_at as dbt_valid_from, + cast(null as date) as dbt_valid_to, + updated_at as dbt_updated_at, + convert(varchar(50), hashbytes('md5', coalesce(cast(id as varchar(8000)), '') + '-' + coalesce(cast(first_name as varchar(8000)), '') + '|' + coalesce(cast(updated_at as varchar(8000)), '')), 2) as dbt_scd_id, + 'False' as dbt_is_deleted +from {database}.{schema}.seed +where id >= 10 and id <= 20; +""" + +_delete_sql = """ +delete from {schema}.seed where id = 1 +""" + + +class SnapshotNewRecordMode: + @pytest.fixture(scope="class") + def snapshots(self): + return {"snapshot.sql": _snapshot_actual_sql} + + @pytest.fixture(scope="class") + def models(self): + return { + "snapshots.yml": _snapshots_yml, + "ref_snapshot.sql": _ref_snapshot_sql, + } + + @pytest.fixture(scope="class") + def seed_new_record_mode(self): + return _seed_new_record_mode + + @pytest.fixture(scope="class") + def invalidate_sql(self): + return _invalidate_sql + + @pytest.fixture(scope="class") + def update_sql(self): + return _update_sql + + @pytest.fixture(scope="class") + def delete_sql(self): + return _delete_sql + + def test_snapshot_new_record_mode( + self, project, seed_new_record_mode, invalidate_sql, update_sql + ): + project.run_sql(seed_new_record_mode) + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + project.run_sql(invalidate_sql) + project.run_sql(update_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) + + project.run_sql(_delete_sql) + + results = run_dbt(["snapshot"]) + assert len(results) == 1 + + +class TestSnapshotNewRecordMode(SnapshotNewRecordMode): + pass diff --git a/tests/functional/adapter/dbt/test_sources.py b/tests/functional/adapter/dbt/test_sources.py new file mode 100644 index 00000000..4c1a214d --- /dev/null +++ b/tests/functional/adapter/dbt/test_sources.py @@ -0,0 +1,69 @@ +import pytest +from dbt.tests.adapter.basic.files import config_materialized_table, config_materialized_view +from dbt.tests.util import run_dbt + +source_regular = """ +version: 2 +sources: +- name: regular + schema: sys + tables: + - name: tables + columns: + - name: name + tests: + - not_null +""" + +source_space_in_name = """ +version: 2 +sources: +- name: 'space in name' + schema: sys + tables: + - name: tables + columns: + - name: name + tests: + - not_null +""" + +select_from_source_regular = """ +select object_id,schema_id from {{ source("regular", "tables") }} +""" + +select_from_source_space_in_name = """ +select object_id,schema_id from {{ source("space in name", "tables") }} +""" + + +# System tables are not supported for data type reasons. 
+@pytest.mark.skip( + reason="The query references an object that is not supported in distributed processing mode." +) +class TestSourcesSQLServer: + @pytest.fixture(scope="class") + def models(self): + return { + "source_regular.yml": source_regular, + "source_space_in_name.yml": source_space_in_name, + "v_select_from_source_regular.sql": config_materialized_view + + select_from_source_regular, + "v_select_from_source_space_in_name.sql": config_materialized_view + + select_from_source_space_in_name, + "t_select_from_source_regular.sql": config_materialized_table + + select_from_source_regular, + "t_select_from_source_space_in_name.sql": config_materialized_table + + select_from_source_space_in_name, + } + + def test_dbt_run(self, project): + run_dbt(["compile"]) + + ls = run_dbt(["list"]) + assert len(ls) == 8 + ls_sources = [src for src in ls if src.startswith("source:")] + assert len(ls_sources) == 2 + + run_dbt(["run"]) + run_dbt(["test"]) diff --git a/tests/functional/adapter/dbt/test_store_test_failures.py b/tests/functional/adapter/dbt/test_store_test_failures.py new file mode 100644 index 00000000..a95f5c38 --- /dev/null +++ b/tests/functional/adapter/dbt/test_store_test_failures.py @@ -0,0 +1,172 @@ +# flake8: noqa: E501 +import pytest +from dbt.tests.adapter.store_test_failures_tests import basic, fixtures +from dbt.tests.util import check_relations_equal, run_dbt + +# used to rename test audit schema to help test schema meet max char limit +# the default is _dbt_test__audit but this runs over the postgres 63 schema name char limit +# without which idempotency conditions will not hold (i.e. dbt can't drop the schema properly) +TEST_AUDIT_SCHEMA_SUFFIX = "dbt_test__aud" + +tests__passing_test = """ +select * from {{ ref('fine_model') }} +where 1=2 +""" + + +class StoreTestFailuresBase: + @pytest.fixture(scope="function", autouse=True) + def setUp(self, project): + self.test_audit_schema = f"{project.test_schema}_{TEST_AUDIT_SCHEMA_SUFFIX}" + run_dbt(["seed"]) + run_dbt(["run"]) + + @pytest.fixture(scope="class") + def seeds(self): + return { + "people.csv": fixtures.seeds__people, + "expected_accepted_values.csv": fixtures.seeds__expected_accepted_values, + "expected_failing_test.csv": fixtures.seeds__expected_failing_test, + "expected_not_null_problematic_model_id.csv": fixtures.seeds__expected_not_null_problematic_model_id, + "expected_unique_problematic_model_id.csv": fixtures.seeds__expected_unique_problematic_model_id, + } + + @pytest.fixture(scope="class") + def tests(self): + return { + "failing_test.sql": fixtures.tests__failing_test, + "passing_test.sql": tests__passing_test, + } + + @pytest.fixture(scope="class") + def properties(self): + return {"schema.yml": fixtures.properties__schema_yml} + + @pytest.fixture(scope="class") + def models(self): + return { + "fine_model.sql": fixtures.models__fine_model, + "fine_model_but_with_a_no_good_very_long_name.sql": fixtures.models__file_model_but_with_a_no_good_very_long_name, + "problematic_model.sql": fixtures.models__problematic_model, + } + + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "seeds": { + "quote_columns": False, + "test": self.column_type_overrides(), + }, + "data_tests": {"+schema": TEST_AUDIT_SCHEMA_SUFFIX}, + } + + def column_type_overrides(self): + return {} + + def run_tests_store_one_failure(self, project): + run_dbt(["test"], expect_pass=False) + + # one test is configured with store_failures: true, make sure it worked + check_relations_equal( + project.adapter, + [ + 
f"{self.test_audit_schema}.unique_problematic_model_id", + "expected_unique_problematic_model_id", + ], + ) + + def run_tests_store_failures_and_assert(self, project): + # make sure this works idempotently for all tests + run_dbt(["test", "--store-failures"], expect_pass=False) + results = run_dbt(["test", "--store-failures"], expect_pass=False) + + # compare test results + actual = [(r.status.value, r.failures) for r in results] + expected = [ + ("error", None), + ("pass", 0), + ("pass", 0), + ("pass", 0), + ("error", None), + ("fail", 2), + ("fail", 2), + ("fail", 10), + ] + assert sorted(actual) == sorted(expected) + + # compare test results stored in database + check_relations_equal( + project.adapter, [f"{self.test_audit_schema}.failing_test", "expected_failing_test"] + ) + check_relations_equal( + project.adapter, + [ + f"{self.test_audit_schema}.not_null_problematic_model_id", + "expected_not_null_problematic_model_id", + ], + ) + check_relations_equal( + project.adapter, + [ + f"{self.test_audit_schema}.unique_problematic_model_id", + "expected_unique_problematic_model_id", + ], + ) + + +class BaseStoreTestFailures(StoreTestFailuresBase): + @pytest.fixture(scope="function") + def clean_up(self, project): + yield + with project.adapter.connection_named("__test"): + relation = project.adapter.Relation.create( + database=project.database, schema=self.test_audit_schema + ) + project.adapter.drop_schema(relation) + + relation = project.adapter.Relation.create( + database=project.database, schema=project.test_schema + ) + project.adapter.drop_schema(relation) + + def column_type_overrides(self): + return { + "expected_unique_problematic_model_id": { + "+column_types": { + "n_records": "bigint", + }, + }, + "expected_accepted_values": { + "+column_types": { + "n_records": "bigint", + }, + }, + } + + def test__store_and_assert(self, project, clean_up): + self.run_tests_store_one_failure(project) + self.run_tests_store_failures_and_assert(project) + + +class TestStoreTestFailures(BaseStoreTestFailures): + pass + + +class TestStoreTestFailuresAsProjectLevelOff(basic.StoreTestFailuresAsProjectLevelOff): + pass + + +class TestStoreTestFailuresAsProjectLevelView(basic.StoreTestFailuresAsProjectLevelView): + pass + + +class TestStoreTestFailuresAsGeneric(basic.StoreTestFailuresAsGeneric): + pass + + +class TestStoreTestFailuresAsProjectLevelEphemeral(basic.StoreTestFailuresAsProjectLevelEphemeral): + pass + + +class TestStoreTestFailuresAsExceptions(basic.StoreTestFailuresAsExceptions): + pass diff --git a/tests/functional/adapter/dbt/test_timestamps.py b/tests/functional/adapter/dbt/test_timestamps.py new file mode 100644 index 00000000..3c2ef343 --- /dev/null +++ b/tests/functional/adapter/dbt/test_timestamps.py @@ -0,0 +1,18 @@ +import pytest +from dbt.tests.adapter.utils.test_timestamps import BaseCurrentTimestamps + + +class TestCurrentTimestampSQLServer(BaseCurrentTimestamps): + @pytest.fixture(scope="class") + def models(self): + return { + "get_current_timestamp.sql": 'select {{ current_timestamp() }} as "current_timestamp"' + } + + @pytest.fixture(scope="class") + def expected_schema(self): + return {"current_timestamp": "datetime2(6)"} + + @pytest.fixture(scope="class") + def expected_sql(self): + return '''select CAST(SYSDATETIME() AS DATETIME2(6)) as "current_timestamp"''' diff --git a/tests/functional/adapter/dbt/test_utils.py b/tests/functional/adapter/dbt/test_utils.py index b5ec8d30..4d458701 100644 --- a/tests/functional/adapter/dbt/test_utils.py +++ 
b/tests/functional/adapter/dbt/test_utils.py @@ -105,9 +105,7 @@ class TestConcat(BaseConcat): pass -@pytest.mark.skip( - reason="Only should implement Aware or Naive. Opted for Naive to align with fabric." -) +@pytest.mark.skip(reason="Should implement only Aware or Naive. Opted for Naive.") class TestCurrentTimestampAware(BaseCurrentTimestampAware): pass