diff --git a/docs/integrations/engines/azuresql.md b/docs/integrations/engines/azuresql.md index 5b54ffa9c6..9a40603e6e 100644 --- a/docs/integrations/engines/azuresql.md +++ b/docs/integrations/engines/azuresql.md @@ -14,6 +14,17 @@ pip install "sqlmesh[azuresql]" ``` pip install "sqlmesh[azuresql-odbc]" ``` +Set `driver: "pyodbc"` in your connection options. + + +#### Python Driver (Official Microsoft driver for Azure SQL): +See [`mssql-python`](https://pypi.org/project/mssql-python/) for more information. +``` +pip install "sqlmesh[azuresql-mssql-python]" +``` + +Set `driver: "mssql-python"` in your connection options. + ### Connection options diff --git a/docs/integrations/engines/fabric.md b/docs/integrations/engines/fabric.md index 90ac3234fc..99102ffdac 100644 --- a/docs/integrations/engines/fabric.md +++ b/docs/integrations/engines/fabric.md @@ -14,6 +14,13 @@ NOTE: Fabric Warehouse is not recommended to be used for the SQLMesh [state conn pip install "sqlmesh[fabric]" ``` +#### Python Driver (Official Microsoft driver for Azure SQL): +See [`mssql-python`](https://pypi.org/project/mssql-python/) for more information. +``` +pip install "sqlmesh[fabric-mssql-python]" +``` +Set `driver: "mssql-python"` in your connection options. + ### Connection options | Option | Description | Type | Required | diff --git a/docs/integrations/engines/mssql.md b/docs/integrations/engines/mssql.md index 4c68219dd2..1aaae23a39 100644 --- a/docs/integrations/engines/mssql.md +++ b/docs/integrations/engines/mssql.md @@ -6,10 +6,20 @@ ``` pip install "sqlmesh[mssql]" ``` + ### Microsoft Entra ID / Azure Active Directory Authentication: ``` pip install "sqlmesh[mssql-odbc]" ``` +Set `driver: "pyodbc"` in your connection options. + +#### Python Driver (Official Microsoft driver for Azure SQL): +See [`mssql-python`](https://pypi.org/project/mssql-python/) for more information. +``` +pip install "sqlmesh[mssql-python]" +``` +Set `driver: "mssql-python"` in your connection options. + ## Incremental by unique key `MERGE` diff --git a/pyproject.toml b/pyproject.toml index 97c190a290..0372b69ae1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ classifiers = [ athena = ["PyAthena[Pandas]"] azuresql = ["pymssql"] azuresql-odbc = ["pyodbc>=5.0.0"] +azuresql-mssql-python = ["mssql-python>=1.1.0;python_version>=\"3.10\""] bigquery = [ "google-cloud-bigquery[pandas]", "google-cloud-bigquery-storage" @@ -84,6 +85,7 @@ dev = [ "PyAthena[Pandas]", "PyGithub>=2.6.0", "pyodbc>=5.0.0", + "mssql-python>=1.1.0;python_version>=\"3.10\"", "pyperf", "pyspark~=3.5.0", "pytest", @@ -109,11 +111,13 @@ dbt = ["dbt-core<2"] dlt = ["dlt"] duckdb = [] fabric = ["pyodbc>=5.0.0"] +fabric-mssql-python = ["mssql-python>=1.1.0;python_version>=\"3.10\""] gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"] github = ["PyGithub>=2.6.0"] motherduck = ["duckdb>=1.2.0"] mssql = ["pymssql"] mssql-odbc = ["pyodbc>=5.0.0"] +mssql-python = ["mssql-python>=1.1.0;python_version>=\"3.10\""] mysql = ["pymysql"] mwaa = ["boto3"] postgres = ["psycopg2"] @@ -215,6 +219,7 @@ module = [ "mysql.*", "pymssql.*", "pyodbc.*", + "mssql_python.*", "psycopg2.*", "pytest_lazyfixture.*", "dbt.adapters.*", diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 638f0c28c8..fae859b3b7 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -2,42 +2,43 @@ import abc import base64 +import importlib import logging import os -import importlib import pathlib import re import typing as t from enum import Enum from functools import partial +from sys import version_info import pydantic +from packaging import version from pydantic import Field from pydantic_core import from_json -from packaging import version from sqlglot import exp -from sqlglot.helper import subclasses from sqlglot.errors import ParseError +from sqlglot.helper import subclasses from sqlmesh.core import engine_adapter from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.config.common import ( + compile_regex_mapping, concurrent_tasks_validator, http_headers_validator, - compile_regex_mapping, ) -from sqlmesh.core.engine_adapter.shared import CatalogSupport from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.engine_adapter.shared import CatalogSupport from sqlmesh.utils import debug_mode_enabled, str_to_bool +from sqlmesh.utils.aws import validate_s3_uri from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import ( ValidationInfo, field_validator, + get_concrete_types_from_typehint, model_validator, validation_error_message, - get_concrete_types_from_typehint, ) -from sqlmesh.utils.aws import validate_s3_uri if t.TYPE_CHECKING: from sqlmesh.core._typing import Self @@ -60,6 +61,7 @@ } MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)") PASSWORD_REGEX = re.compile(r"(password=)(\S+)") +SUPPORTS_MSSQL_PYTHON_DRIVER = (version_info.major, version_info.minor) >= (3, 10) def _get_engine_import_validator( @@ -955,7 +957,7 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: # if a client_secret exists, then a client_id also exists and we are using M2M # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py - from databricks.sdk.core import oauth_service_principal, Config + from databricks.sdk.core import Config, oauth_service_principal config = Config( host=f"https://{self.server_hostname}", @@ -1110,8 +1112,8 @@ def _engine_adapter(self) -> t.Type[EngineAdapter]: def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: """The static connection kwargs for this connection""" import google.auth - from google.auth import impersonated_credentials from google.api_core import client_info, client_options + from google.auth import impersonated_credentials from google.oauth2 import credentials, service_account if self.method == BigQueryConnectionMethod.OAUTH: @@ -1516,7 +1518,7 @@ class MSSQLConnectionConfig(ConnectionConfig): tds_version: t.Optional[str] = None # Driver options - driver: t.Literal["pymssql", "pyodbc"] = "pymssql" + driver: t.Literal["pymssql", "pyodbc", "mssql-python"] = "pymssql" # PyODBC specific options driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server" trust_server_certificate: t.Optional[bool] = None @@ -1543,7 +1545,11 @@ def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: driver = data.get("driver", "pymssql") # Define the mapping of driver to import module and extra name - driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")} + driver_configs = { + "pymssql": ("pymssql", "mssql"), + "pyodbc": ("pyodbc", "mssql-odbc"), + "mssql-python": ("mssql_python", "mssql-python"), + } if driver not in driver_configs: raise ValueError(f"Unsupported driver: {driver}") @@ -1589,6 +1595,18 @@ def _connection_kwargs_keys(self) -> t.Set[str]: base_keys.discard("tds_version") base_keys.discard("conn_properties") + elif self.driver == "mssql-python": + base_keys.update( + { + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + ) + # Remove pymssql-specific parameters + base_keys.discard("tds_version") + base_keys.discard("conn_properties") + return base_keys @property @@ -1602,95 +1620,176 @@ def _connection_factory(self) -> t.Callable: return pymssql.connect - import pyodbc - - def connect(**kwargs: t.Any) -> t.Callable: - # Extract parameters for connection string - host = kwargs.pop("host") - port = kwargs.pop("port", 1433) - database = kwargs.pop("database", "") - user = kwargs.pop("user", None) - password = kwargs.pop("password", None) - driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") - trust_server_certificate = kwargs.pop("trust_server_certificate", False) - encrypt = kwargs.pop("encrypt", True) - login_timeout = kwargs.pop("login_timeout", 60) - - # Build connection string - conn_str_parts = [ - f"DRIVER={{{driver_name}}}", - f"SERVER={host},{port}", - ] - - if database: - conn_str_parts.append(f"DATABASE={database}") - - # Add security options - conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") - if trust_server_certificate: - conn_str_parts.append("TrustServerCertificate=YES") - - conn_str_parts.append(f"Connection Timeout={login_timeout}") - - # Standard SQL Server authentication - if user: - conn_str_parts.append(f"UID={user}") - if password: - conn_str_parts.append(f"PWD={password}") - - # Add any additional ODBC properties from the odbc_properties dictionary - if self.odbc_properties: - for key, value in self.odbc_properties.items(): - # Skip properties that we've already set above - if key.lower() in ( - "driver", - "server", - "database", - "uid", - "pwd", - "encrypt", - "trustservercertificate", - "connection timeout", - ): - continue + if self.driver == "mssql-python": + # The `mssql-python` implementation is API-compatible with + # with the `pyodbc` equivalent for documented parameters. + + if not SUPPORTS_MSSQL_PYTHON_DRIVER: + raise ConfigError("The `mssql-python` driver requires Python 3.10 or higher.") + + def connect_mssql_python(**kwargs: t.Any) -> t.Callable: + # Extract parameters for connection string + host = kwargs.pop("host") + port = kwargs.pop("port", 1433) + database = kwargs.pop("database", "") + user = kwargs.pop("user", None) + password = kwargs.pop("password", None) + authentication = kwargs.pop("authentication", None) + trust_server_certificate = kwargs.pop("trust_server_certificate", False) + encrypt = kwargs.pop("encrypt", True) + login_timeout = kwargs.pop("login_timeout", 59) + login_attempts = kwargs.pop("login_attempts", 1) # TODO: document + + # Build connection string + conn_str_parts = [ + f"Server={host},{port}", + ] + + if database: + conn_str_parts.append(f"Database={database}") + + # Add security options + conn_str_parts.append(f"Encrypt={'yes' if encrypt else 'no'}") + if trust_server_certificate: + conn_str_parts.append("TrustServerCertificate=yes") + + conn_str_parts.append(f"ConnectRetryCount={login_attempts}") + conn_str_parts.append(f"ConnectRetryInterval={min(int(login_timeout), 60)}") + + # Standard SQL Server authentication + if user: + conn_str_parts.append(f"UID={user}") + if password: + conn_str_parts.append(f"PWD={password}") + if authentication: + conn_str_parts.append(f"Authentication={authentication}") + + # Add any additional ODBC properties from the odbc_properties dictionary + if self.odbc_properties: + for key, value in self.odbc_properties.items(): + # Skip properties that we've already set above + if key.lower() in ( + "driver", + "server", + "database", + "uid", + "pwd", + "encrypt", + "trustservercertificate", + "connection timeout", + ): + continue + + # Handle boolean values properly + if isinstance(value, bool): + conn_str_parts.append(f"{key}={'yes' if value else 'no'}") + else: + conn_str_parts.append(f"{key}={value}") + + # Create the connection + conn_str = ";".join(conn_str_parts) + + import mssql_python + + conn = mssql_python.connect(conn_str, autocommit=kwargs.get("autocommit", False)) + + return conn + + return connect_mssql_python - # Handle boolean values properly - if isinstance(value, bool): - conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") - else: - conn_str_parts.append(f"{key}={value}") - - # Create the connection string - conn_str = ";".join(conn_str_parts) - - conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) - - # Set up output converters for MSSQL-specific data types - # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc - # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 - def handle_datetimeoffset(dto_value: t.Any) -> t.Any: - from datetime import datetime, timedelta, timezone - import struct - - # Unpack the DATETIMEOFFSET binary format: - # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset) - tup = struct.unpack("<6hI2h", dto_value) - return datetime( - tup[0], - tup[1], - tup[2], - tup[3], - tup[4], - tup[5], - tup[6] // 1000, - timezone(timedelta(hours=tup[7], minutes=tup[8])), - ) + if self.driver == "pyodbc": - conn.add_output_converter(-155, handle_datetimeoffset) + def connect_pyodbc(**kwargs: t.Any) -> t.Callable: + # Extract parameters for connection string + host = kwargs.pop("host") + port = kwargs.pop("port", 1433) + database = kwargs.pop("database", "") + user = kwargs.pop("user", None) + password = kwargs.pop("password", None) + driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") + trust_server_certificate = kwargs.pop("trust_server_certificate", False) + encrypt = kwargs.pop("encrypt", True) + login_timeout = kwargs.pop("login_timeout", 60) + + # Build connection string + conn_str_parts = [ + f"DRIVER={{{driver_name}}}", + f"SERVER={host},{port}", + ] + + if database: + conn_str_parts.append(f"DATABASE={database}") + + # Add security options + conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") + if trust_server_certificate: + conn_str_parts.append("TrustServerCertificate=YES") + + conn_str_parts.append(f"Connection Timeout={login_timeout}") + + # Standard SQL Server authentication + if user: + conn_str_parts.append(f"UID={user}") + if password: + conn_str_parts.append(f"PWD={password}") + + # Add any additional ODBC properties from the odbc_properties dictionary + if self.odbc_properties: + for key, value in self.odbc_properties.items(): + # Skip properties that we've already set above + if key.lower() in ( + "driver", + "server", + "database", + "uid", + "pwd", + "encrypt", + "trustservercertificate", + "connection timeout", + ): + continue + + # Handle boolean values properly + if isinstance(value, bool): + conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") + else: + conn_str_parts.append(f"{key}={value}") + + # Create the connection + conn_str = ";".join(conn_str_parts) + + import pyodbc + + conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) + + # Set up output converters for MSSQL-specific data types + # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc + # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 + def handle_datetimeoffset_pyodbc(dto_value: t.Any) -> t.Any: + import struct + from datetime import datetime, timedelta, timezone + + # Unpack the DATETIMEOFFSET binary format: + # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset) + tup = struct.unpack("<6hI2h", dto_value) + return datetime( + tup[0], + tup[1], + tup[2], + tup[3], + tup[4], + tup[5], + tup[6] // 1000, + timezone(timedelta(hours=tup[7], minutes=tup[8])), + ) - return conn + conn.add_output_converter(-155, handle_datetimeoffset_pyodbc) - return connect + return conn + + return connect_pyodbc + + raise ValueError(f"Unsupported driver: {self.driver}") @property def _extra_engine_config(self) -> t.Dict[str, t.Any]: @@ -1718,7 +1817,7 @@ class FabricConnectionConfig(MSSQLConnectionConfig): DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric" # type: ignore DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric" # type: ignore DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17 # type: ignore - driver: t.Literal["pyodbc"] = "pyodbc" + driver: t.Literal["pyodbc", "mssql-python"] = "pyodbc" workspace_id: str tenant_id: str autocommit: t.Optional[bool] = True @@ -2124,9 +2223,10 @@ def _engine_adapter(self) -> t.Type[EngineAdapter]: @property def _connection_factory(self) -> t.Callable: + from functools import partial + from clickhouse_connect.dbapi import connect # type: ignore from clickhouse_connect.driver import httputil # type: ignore - from functools import partial pool_manager_options: t.Dict[str, t.Any] = dict( # Match the maxsize to the number of concurrent tasks diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index dd979a2551..855a92d39f 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1,31 +1,32 @@ import base64 import re import typing as t +from unittest.mock import MagicMock, patch import pytest from _pytest.fixtures import FixtureRequest from sqlglot import exp -from unittest.mock import patch, MagicMock from sqlmesh.core.config.connection import ( + INIT_DISPLAY_INFO_TO_TYPE, + SUPPORTS_MSSQL_PYTHON_DRIVER, + AthenaConnectionConfig, BigQueryConnectionConfig, ClickhouseConnectionConfig, ConnectionConfig, DatabricksConnectionConfig, DuckDBAttachOptions, - FabricConnectionConfig, DuckDBConnectionConfig, + FabricConnectionConfig, GCPPostgresConnectionConfig, MotherDuckConnectionConfig, + MSSQLConnectionConfig, MySQLConnectionConfig, PostgresConnectionConfig, SnowflakeConnectionConfig, TrinoAuthenticationMethod, - AthenaConnectionConfig, - MSSQLConnectionConfig, _connection_config_validator, _get_engine_import_validator, - INIT_DISPLAY_INFO_TO_TYPE, ) from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import PydanticModel @@ -968,9 +969,10 @@ def test_motherduck_attach_options(): def test_duckdb_multithreaded_connection_factory(make_config): + from threading import Thread + from sqlmesh.core.engine_adapter import DuckDBEngineAdapter from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool - from threading import Thread config = make_config(type="duckdb") @@ -1462,6 +1464,13 @@ def test_mssql_engine_import_validator(): mock_import.side_effect = ImportError("No module named 'pyodbc'") MSSQLConnectionConfig(host="localhost", driver="pyodbc") + # Test MSSQL Python driver suggests mssql-python extra when import fails + if SUPPORTS_MSSQL_PYTHON_DRIVER: + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql-python\]\""): + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'mssql_python'") + MSSQLConnectionConfig(host="localhost", driver="mssql-python") + # Test PyMSSQL driver suggests mssql extra when import fails with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql\]\""): with patch("importlib.import_module") as mock_import: @@ -1493,6 +1502,14 @@ def test_mssql_connection_config_parameter_validation(make_config): assert isinstance(config, MSSQLConnectionConfig) assert config.driver == "pyodbc" + # Test explicit mssql-python driver + if SUPPORTS_MSSQL_PYTHON_DRIVER: + config = make_config( + type="mssql", host="localhost", driver="mssql-python", check_import=False + ) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "mssql-python" + # Test explicit pymssql driver config = make_config(type="mssql", host="localhost", driver="pymssql", check_import=False) assert isinstance(config, MSSQLConnectionConfig) @@ -1515,6 +1532,22 @@ def test_mssql_connection_config_parameter_validation(make_config): assert config.encrypt is False assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} + # Test mssql-python specific parameters + if SUPPORTS_MSSQL_PYTHON_DRIVER: + config = make_config( + type="mssql", + host="localhost", + driver="mssql-python", + trust_server_certificate=True, + encrypt=False, + odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal"}, + check_import=False, + ) + assert isinstance(config, MSSQLConnectionConfig) + assert config.trust_server_certificate is True + assert config.encrypt is False + assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} + # Test pymssql specific parameters config = make_config( type="mssql", @@ -1575,7 +1608,33 @@ def test_mssql_connection_kwargs_keys(): assert "tds_version" not in pyodbc_keys assert "conn_properties" not in pyodbc_keys - + # Test mssql-python driver keys + if SUPPORTS_MSSQL_PYTHON_DRIVER: + config = MSSQLConnectionConfig(host="localhost", driver="mssql-python", check_import=False) + mssql_python_keys = config._connection_kwargs_keys + expected_mssql_python_keys = { + "password", + "user", + "database", + "host", + "timeout", + "login_timeout", + "charset", + "appname", + "port", + "autocommit", + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + assert mssql_python_keys == expected_mssql_python_keys + + # Verify mssql-python keys don't include pymssql-specific parameters + assert "tds_version" not in mssql_python_keys + assert "conn_properties" not in mssql_python_keys + + +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") def test_mssql_pyodbc_connection_string_generation(): """Test pyodbc.connect gets invoked with the correct ODBC connection string.""" with patch("pyodbc.connect") as mock_pyodbc_connect: @@ -1663,6 +1722,7 @@ def test_mssql_pyodbc_connection_string_with_odbc_properties(): assert conn_str.count("TrustServerCertificate") == 1 +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") def test_mssql_pyodbc_connection_string_minimal(): """Test pyodbc connection string with minimal configuration.""" with patch("pyodbc.connect") as mock_pyodbc_connect: @@ -1689,6 +1749,121 @@ def test_mssql_pyodbc_connection_string_minimal(): assert mock_pyodbc_connect.call_args[1]["autocommit"] is True +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_mssql_mssql_python_connection_string_generation(): + """Test mssql_python.connect gets invoked with the correct ODBC connection string.""" + with patch("mssql_python.connect") as mock_mssql_python_connect: + # Mock the return value to have the methods we need + mock_connection = mock_mssql_python_connect.return_value + + # Create a mssql-python config + config = MSSQLConnectionConfig( + host="testserver.database.windows.net", + port=1433, + database="testdb", + user="testuser", + password="testpass", + driver="mssql-python", + trust_server_certificate=True, + encrypt=True, + login_timeout=30, + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify mssql_python.connect was called with the correct connection string + mock_mssql_python_connect.assert_called_once() + call_args = mock_mssql_python_connect.call_args + + # Check the connection string (first argument) + conn_str = call_args[0][0] + expected_parts = [ + "Server=testserver.database.windows.net,1433", + "Database=testdb", + "Encrypt=yes", + "TrustServerCertificate=yes", + "ConnectRetryCount=1", + "ConnectRetryInterval=30", + "UID=testuser", + "PWD=testpass", + ] + + for part in expected_parts: + assert part in conn_str + + # Check autocommit parameter + assert call_args[1]["autocommit"] is False + + +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_mssql_mssql_python_connection_string_with_odbc_properties(): + """Test mssql-python connection string includes custom ODBC properties.""" + with patch("mssql_python.connect") as mock_mssql_python_connect: + # Create a mssql-python config with custom ODBC properties + config = MSSQLConnectionConfig( + host="testserver.database.windows.net", + database="testdb", + user="client-id", + password="client-secret", + driver="mssql-python", + odbc_properties={ + "Authentication": "ActiveDirectoryServicePrincipal", + "ClientCertificate": "/path/to/cert.pem", + "TrustServerCertificate": "NO", # This should be ignored since we set it explicitly + }, + trust_server_certificate=True, # This should take precedence + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify mssql_python.connect was called + mock_mssql_python_connect.assert_called_once() + conn_str = mock_mssql_python_connect.call_args[0][0] + + # Check that custom ODBC properties are included + assert "Authentication=ActiveDirectoryServicePrincipal" in conn_str + assert "ClientCertificate=/path/to/cert.pem" in conn_str + + # Verify that explicit trust_server_certificate takes precedence + assert "TrustServerCertificate=yes" in conn_str + + # Should not have the conflicting property from odbc_properties + assert conn_str.count("TrustServerCertificate") == 1 + + +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_mssql_mssql_python_connection_string_minimal(): + """Test mssql-python connection string with minimal configuration.""" + with patch("mssql_python.connect") as mock_mssql_python_connect: + config = MSSQLConnectionConfig( + host="localhost", + driver="mssql-python", + autocommit=True, + check_import=False, + ) + + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + mock_mssql_python_connect.assert_called_once() + conn_str = mock_mssql_python_connect.call_args[0][0] + + # Check basic required parts + assert "Server=localhost,1433" in conn_str + assert "Encrypt=yes" in conn_str # Default encrypt=True + assert "ConnectRetryCount=1" in conn_str # Default timeout + assert "ConnectRetryInterval=60" in conn_str # Default timeout + + # Check autocommit parameter + assert mock_mssql_python_connect.call_args[1]["autocommit"] is True + + def test_mssql_pymssql_connection_factory(): """Test pymssql connection factory returns correct function.""" # Mock the import of pymssql at the module level @@ -1718,8 +1893,8 @@ def test_mssql_pymssql_connection_factory(): def test_mssql_pyodbc_connection_datetimeoffset_handling(): """Test that the MSSQL pyodbc connection properly handles DATETIMEOFFSET conversion.""" - from datetime import datetime, timezone, timedelta import struct + from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch with patch("pyodbc.connect") as mock_pyodbc_connect: @@ -1790,8 +1965,8 @@ def mock_add_output_converter(sql_type, converter_func): def test_mssql_pyodbc_connection_negative_timezone_offset(): """Test DATETIMEOFFSET handling with negative timezone offset at connection level.""" - from datetime import datetime, timezone, timedelta import struct + from datetime import datetime, timedelta, timezone from unittest.mock import Mock, patch with patch("pyodbc.connect") as mock_pyodbc_connect: @@ -1842,6 +2017,134 @@ def mock_add_output_converter(sql_type, converter_func): assert result.tzinfo == timezone(timedelta(hours=-8)) +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_mssql_mssql_python_connection_datetimeoffset_handling(): + """Test that the MSSQL mssql-python connection properly handles DATETIMEOFFSET conversion.""" + import struct + from datetime import datetime, timedelta, timezone + from unittest.mock import Mock, patch + + with patch("mssql_python.connect") as mock_mssql_python_connect: + # Track calls to add_output_converter + converter_calls = [] + + def mock_add_output_converter(sql_type, converter_func): + converter_calls.append((sql_type, converter_func)) + + # Create a mock connection that will be returned by mssql_python.connect + mock_connection = Mock() + mock_connection.add_output_converter = mock_add_output_converter + mock_mssql_python_connect.return_value = mock_connection + + config = MSSQLConnectionConfig( + host="localhost", + driver="mssql-python", # DATETIMEOFFSET handling is mssql-python-specific + check_import=False, + ) + + # Get the connection factory and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify that add_output_converter was called for SQL type -155 (DATETIMEOFFSET) + assert len(converter_calls) == 1 + sql_type, converter_func = converter_calls[0] + assert sql_type == -155 + + # Test the converter function with actual DATETIMEOFFSET binary data + # Create a test DATETIMEOFFSET value: 2023-12-25 15:30:45.123456789 +05:30 + year, month, day = 2023, 12, 25 + hour, minute, second = 15, 30, 45 + nanoseconds = 123456789 + tz_hour_offset, tz_minute_offset = 5, 30 + + # Pack the binary data according to the DATETIMEOFFSET format + binary_data = struct.pack( + "<6hI2h", + year, + month, + day, + hour, + minute, + second, + nanoseconds, + tz_hour_offset, + tz_minute_offset, + ) + + # Convert using the registered converter + result = converter_func(binary_data) + + # Verify the result + expected_dt = datetime( + 2023, + 12, + 25, + 15, + 30, + 45, + 123456, # microseconds = nanoseconds // 1000 + timezone(timedelta(hours=5, minutes=30)), + ) + assert result == expected_dt + assert result.tzinfo == timezone(timedelta(hours=5, minutes=30)) + + +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_mssql_mssql_python_connection_negative_timezone_offset(): + """Test DATETIMEOFFSET handling with negative timezone offset at connection level.""" + import struct + from datetime import datetime, timedelta, timezone + from unittest.mock import Mock, patch + + with patch("mssql_python.connect") as mock_mssql_python_connect: + converter_calls = [] + + def mock_add_output_converter(sql_type, converter_func): + converter_calls.append((sql_type, converter_func)) + + mock_connection = Mock() + mock_connection.add_output_converter = mock_add_output_converter + mock_mssql_python_connect.return_value = mock_connection + + config = MSSQLConnectionConfig( + host="localhost", + driver="mssql-python", # DATETIMEOFFSET handling is mssql-python-specific + check_import=False, + ) + + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Get the converter function + _, converter_func = converter_calls[0] + + # Test with negative timezone offset: 2023-01-01 12:00:00.0 -08:00 + year, month, day = 2023, 1, 1 + hour, minute, second = 12, 0, 0 + nanoseconds = 0 + tz_hour_offset, tz_minute_offset = -8, 0 + + binary_data = struct.pack( + "<6hI2h", + year, + month, + day, + hour, + minute, + second, + nanoseconds, + tz_hour_offset, + tz_minute_offset, + ) + + result = converter_func(binary_data) + + expected_dt = datetime(2023, 1, 1, 12, 0, 0, 0, timezone(timedelta(hours=-8, minutes=0))) + assert result == expected_dt + assert result.tzinfo == timezone(timedelta(hours=-8)) + + def test_fabric_connection_config_defaults(make_config): """Test Fabric connection config defaults to pyodbc and autocommit=True.""" config = make_config( @@ -1861,8 +2164,8 @@ def test_fabric_connection_config_defaults(make_config): assert isinstance(config.create_engine_adapter(), FabricEngineAdapter) -def test_fabric_connection_config_parameter_validation(make_config): - """Test Fabric connection config parameter validation.""" +def test_fabric_pyodbc_connection_config_parameter_validation(make_config): + """Test Fabric pyodbc connection config parameter validation.""" # Test that FabricConnectionConfig correctly handles pyodbc-specific parameters. config = make_config( type="fabric", @@ -1883,7 +2186,33 @@ def test_fabric_connection_config_parameter_validation(make_config): assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} # Test that specifying a different driver for Fabric raises an error - with pytest.raises(ConfigError, match=r"Input should be 'pyodbc'"): + with pytest.raises(ConfigError, match=r"Input should be 'pyodbc' or 'mssql-python'"): + make_config(type="fabric", host="localhost", driver="pymssql", check_import=False) + + +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_fabric_mssql_python_connection_config_parameter_validation(make_config): + """Test Fabric mssql-python connection config parameter validation.""" + # Test that FabricConnectionConfig correctly handles mssql-python-specific parameters. + config = make_config( + type="fabric", + host="localhost", + driver="mssql-python", + trust_server_certificate=True, + encrypt=False, + odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal"}, + workspace_id="test-workspace-id", + tenant_id="test-tenant-id", + check_import=False, + ) + assert isinstance(config, FabricConnectionConfig) + assert config.driver == "mssql-python" # Driver is fixed to mssql-python + assert config.trust_server_certificate is True + assert config.encrypt is False + assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} + + # Test that specifying a different driver for Fabric raises an error + with pytest.raises(ConfigError, match=r"Input should be 'pyodbc' or 'mssql-python'"): make_config(type="fabric", host="localhost", driver="pymssql", check_import=False) @@ -1934,6 +2263,54 @@ def test_fabric_pyodbc_connection_string_generation(): assert call_args[1]["autocommit"] is True +@pytest.mark.xfail(not SUPPORTS_MSSQL_PYTHON_DRIVER, reason="mssql-python driver not supported") +def test_fabric_mssql_python_connection_string_generation(): + """Test that the Fabric mssql-python connection gets invoked with the correct connection string.""" + with patch("mssql_python.connect") as mock_mssql_python_connect: + # Create a Fabric config + config = FabricConnectionConfig( + driver="mssql-python", + host="testserver.datawarehouse.fabric.microsoft.com", + port=1433, + database="testdb", + user="testuser", + password="testpass", + trust_server_certificate=True, + encrypt=True, + login_timeout=30, + workspace_id="test-workspace-id", + tenant_id="test-tenant-id", + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify mssql_python.connect was called with the correct connection string + mock_mssql_python_connect.assert_called_once() + call_args = mock_mssql_python_connect.call_args + + # Check the connection string (first argument) + conn_str = call_args[0][0] + expected_parts = [ + "Server=testserver.datawarehouse.fabric.microsoft.com,1433", + "Database=testdb", + "Encrypt=yes", + "TrustServerCertificate=yes", + "ConnectRetryCount=1", + "ConnectRetryInterval=30", + "UID=testuser", + "PWD=testpass", + ] + + for part in expected_parts: + assert part in conn_str + + # Check autocommit parameter, should default to True for Fabric + assert call_args[1]["autocommit"] is True + + def test_schema_differ_overrides(make_config) -> None: default_config = make_config(type="duckdb") assert default_config.schema_differ_overrides is None