diff --git a/extensions/positron-python/python_files/posit/pinned-test-requirements.txt b/extensions/positron-python/python_files/posit/pinned-test-requirements.txt
index 747454810b0..2f9fbc216fd 100644
--- a/extensions/positron-python/python_files/posit/pinned-test-requirements.txt
+++ b/extensions/positron-python/python_files/posit/pinned-test-requirements.txt
@@ -43,6 +43,8 @@ polars==1.35.2
 polars[timezone]==1.35.2; sys_platform == 'win32'
 pyarrow==21.0.0; python_version == '3.9'
 pyarrow==22.0.0; python_version > '3.10'
+pymssql==2.3.10; python_version < '3.14'
+pyodbc==5.3.0; python_version < '3.14'
 pytest==8.4.2; python_version == '3.9'
 pytest==9.0.1; python_version >= '3.10'
 pytest-asyncio==1.2.0; python_version == '3.9'
diff --git a/extensions/positron-python/python_files/posit/positron/connections.py b/extensions/positron-python/python_files/posit/positron/connections.py
index c8b29cb9d64..9da5f4a13ce 100644
--- a/extensions/positron-python/python_files/posit/positron/connections.py
+++ b/extensions/positron-python/python_files/posit/positron/connections.py
@@ -44,6 +44,25 @@ logger = logging.getLogger(__name__)
+def _is_pyodbc_sqlserver(conn: Any) -> bool:
+    """Return True if `conn` is a pyodbc connection to SQL Server."""
+    if not safe_isinstance(conn, "pyodbc", "Connection"):
+        return False
+
+    try:
+        import pyodbc
+    except ImportError:
+        return False
+
+    try:
+        dbms_name = str(conn.getinfo(pyodbc.SQL_DBMS_NAME))
+    except Exception:
+        return False
+
+    upper_name = dbms_name.upper()
+    return "SQL SERVER" in upper_name or "AZURE SQL" in upper_name
+
+
 class ConnectionWarning(UserWarning):
     """
     Warning raised when there are issues in the Connections Pane relevant to the user.
     """
@@ -322,6 +341,8 @@ def _wrap_connection(self, obj: Any) -> Connection:
             return GoogleBigQueryConnection(obj)
         elif safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection"):
             return SnowflakeConnection(obj)
+        elif _is_pyodbc_sqlserver(obj) or safe_isinstance(obj, "pymssql", "Connection"):
+            return SQLServerConnection(obj)
         elif safe_isinstance(obj, "databricks.sql.client", "Connection"):
             return DatabricksConnection(obj)
         else:
@@ -343,6 +364,8 @@ def object_is_supported(self, obj: Any) -> bool:
                 )
                 or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection")
                 or safe_isinstance(obj, "databricks.sql.client", "Connection")
+                or _is_pyodbc_sqlserver(obj)
+                or safe_isinstance(obj, "pymssql", "Connection")
             )
         except Exception as err:
             logger.error(f"Error checking supported {err}")
@@ -1159,6 +1182,247 @@ def _make_code(self):
         return code
+
+class SQLServerConnection(Connection):
+    """Support for SQL Server connections to databases."""
+
+    def __init__(self, conn: Any):
+        self.conn = conn
+
+        try:
+            self.host = self._fetch_one_value("SELECT @@SERVERNAME")
+        except Exception:
+            self.host = ""
+
+        self.database: str | None = None
+        with contextlib.suppress(Exception):
+            db_name = self._fetch_one_value("SELECT DB_NAME()")
+            self.database = str(db_name) if db_name not in (None, "") else None
+
+        self.type = "SQLServer" + (" (pyodbc)" if self._is_pyodbc() else " (pymssql)")
+        self.display_name = f"{self.type} - {self.host}"
+
+        self.code = self._make_code()
+        self.icon = 
"data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iNDgiIGhlaWdodD0iNDgiIHZpZXdCb3g9IjAgMCA0OCA0OCIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPGcgY2xpcC1wYXRoPSJ1cmwoI2NsaXAwXzg2NjZfNjYyNykiPgo8ZyBjbGlwLXBhdGg9InVybCgjY2xpcDFfODY2Nl82NjI3KSI+CjxtYXNrIGlkPSJtYXNrMF84NjY2XzY2MjciIHN0eWxlPSJtYXNrLXR5cGU6bHVtaW5hbmNlIiBtYXNrVW5pdHM9InVzZXJTcGFjZU9uVXNlIiB4PSIwIiB5PSIwIiB3aWR0aD0iNDgiIGhlaWdodD0iNDgiPgo8cGF0aCBkPSJNNDggMEgwVjQ4SDQ4VjBaIiBmaWxsPSJ3aGl0ZSIvPgo8L21hc2s+CjxnIG1hc2s9InVybCgjbWFzazBfODY2Nl82NjI3KSI+CjxwYXRoIGQ9Ik0yMy45OTgyIDMzLjZDMzQuNjAxOSAzMy42IDQzLjE5ODIgMzYuODIzNCA0My4xOTgyIDQwLjhWMjYuMzk5OEM0My4xOTgyIDIyLjQyMzQgMzQuNjAxOSAxOS4xOTk4IDIzLjk5ODIgMTkuMTk5OEgxOS4xOTgyVjMzLjZIMjMuOTk4MloiIGZpbGw9InVybCgjcGFpbnQwX2xpbmVhcl84NjY2XzY2MjcpIi8+CjxwYXRoIGQ9Ik00My4xOTg4IDI2LjM5OThDNDMuMTk4OCAzMC4zNzYzIDM0LjYwMjggMzMuNTk5NyAyMy45OTg4IDMzLjU5OTdDMTcuNTMxOCAzMy41OTk3IDExLjgxMTYgMzIuNDAwOCA4LjMzMjY3IDMwLjU2MzJDNi40NDg4MyAyOS43MjggNC43OTg4MyAzMC45OTM2IDQuNzk4ODMgMzIuNzA5MVY0MC43OTk3QzQuNzk4ODMgNDQuNzc2MyAxMy4zOTUgNDcuOTk5NyAyMy45OTg4IDQ3Ljk5OTdDMzQuNjAyOCA0Ny45OTk3IDQzLjE5ODggNDQuNzc2MyA0My4xOTg4IDQwLjc5OTdWMjYuMzk5OFoiIGZpbGw9InVybCgjcGFpbnQxX3JhZGlhbF84NjY2XzY2MjcpIi8+CjxwYXRoIGQ9Ik00My4xOTg4IDI2LjM5OThDNDMuMTk4OCAzMC4zNzYzIDM0LjYwMjggMzMuNTk5NyAyMy45OTg4IDMzLjU5OTdDMTcuNTMxOCAzMy41OTk3IDExLjgxMTYgMzIuNDAwOCA4LjMzMjY3IDMwLjU2MzJDNi40NDg4MyAyOS43MjggNC43OTg4MyAzMC45OTM2IDQuNzk4ODMgMzIuNzA5MVY0MC43OTk3QzQuNzk4ODMgNDQuNzc2MyAxMy4zOTUgNDcuOTk5NyAyMy45OTg4IDQ3Ljk5OTdDMzQuNjAyOCA0Ny45OTk3IDQzLjE5ODggNDQuNzc2MyA0My4xOTg4IDQwLjc5OTdWMjYuMzk5OFoiIGZpbGw9InVybCgjcGFpbnQyX3JhZGlhbF84NjY2XzY2MjcpIi8+CjxwYXRoIGQ9Ik0yMy45OTg4IDBDMzQuNjAyNiAwIDQzLjE5ODkgMy4yMjM1NSA0My4xOTg5IDcuMlYxNS4zNTYyQzQzLjE5ODkgMTYuNDcxOSA0Mi4wMzYyIDE4LjQzMTMgMzkuNjYyMSAxNy40MzQ5QzM2LjE4MjYgMTUuNTk4MyAzMC40NjM5IDE0LjQgMjMuOTk4OCAxNC40QzEzLjM5NSAxNC40IDQuNzk4ODMgMTcuNjIzNSA0Ljc5ODgzIDIxLjZWNy4yQzQuNzk4ODMgMy4yMjM1NSAxMy4zOTUgMCAyMy45OTg4IDBaIiBmaWxsPSJ1cmwoI3BhaW50M19saW5lYXJfODY2Nl82NjI3KSIvPgo8cGF0aCBkPSJNMjMuOTk4OCAxNC4zOTk4QzEzLjM5NSAxNC4zOTk4IDQuNzk4ODMgMTEuMTc2MyA0Ljc5ODgzIDcuMTk5ODNWMjEuNTk5OEM0Ljc5ODgzIDI1LjU3NjMgMTMuMzk1IDI4LjggMjMuOTk4OCAyOC44SDI1LjE5ODhDMjcuMTg3MSAyOC44IDI4Ljc5ODkgMjcuMTg4IDI4Ljc5ODkgMjUuMTk5OFYxNy45OTk4QzI4Ljc5ODkgMTYuMDExNiAyNy4xODcxIDE0LjM5OTggMjUuMTk4OCAxNC4zOTk4SDIzLjk5ODhaIiBmaWxsPSJ1cmwoI3BhaW50NF9yYWRpYWxfODY2Nl82NjI3KSIvPgo8cGF0aCBkPSJNMjMuOTk4OCAxNC4zOTk4QzEzLjM5NSAxNC4zOTk4IDQuNzk4ODMgMTEuMTc2MyA0Ljc5ODgzIDcuMTk5ODNWMjEuNTk5OEM0Ljc5ODgzIDI1LjU3NjMgMTMuMzk1IDI4LjggMjMuOTk4OCAyOC44SDI1LjE5ODhDMjcuMTg3MSAyOC44IDI4Ljc5ODkgMjcuMTg4IDI4Ljc5ODkgMjUuMTk5OFYxNy45OTk4QzI4Ljc5ODkgMTYuMDExNiAyNy4xODcxIDE0LjM5OTggMjUuMTk4OCAxNC4zOTk4SDIzLjk5ODhaIiBmaWxsPSJ1cmwoI3BhaW50NV9yYWRpYWxfODY2Nl82NjI3KSIvPgo8L2c+CjwvZz4KPC9nPgo8ZGVmcz4KPGxpbmVhckdyYWRpZW50IGlkPSJwYWludDBfbGluZWFyXzg2NjZfNjYyNyIgeDE9IjM1Ljk5ODIiIHkxPSIxOS4xOTk4IiB4Mj0iMTUuOTMzNCIgeTI9IjM1LjgxOTQiIGdyYWRpZW50VW5pdHM9InVzZXJTcGFjZU9uVXNlIj4KPHN0b3Agc3RvcC1jb2xvcj0iIzAwOTRGMCIvPgo8c3RvcCBvZmZzZXQ9IjAuMjQzMDQ3IiBzdG9wLWNvbG9yPSIjMDA3OEQ0Ii8+CjxzdG9wIG9mZnNldD0iMC41ODQ0MDQiIHN0b3AtY29sb3I9IiMyMDUyQ0IiLz4KPHN0b3Agb2Zmc2V0PSIwLjgzMDYzOSIgc3RvcC1jb2xvcj0iIzMxMkE5QSIvPgo8L2xpbmVhckdyYWRpZW50Pgo8cmFkaWFsR3JhZGllbnQgaWQ9InBhaW50MV9yYWRpYWxfODY2Nl82NjI3IiBjeD0iMCIgY3k9IjAiIHI9IjEiIGdyYWRpZW50VW5pdHM9InVzZXJTcGFjZU9uVXNlIiBncmFkaWVudFRyYW5zZm9ybT0idHJhbnNsYXRlKDMzLjU5ODggMjkuNTQ5OSkgcm90YXRlKDE0My41OTEpIHNjYWxlKDI5LjgyMTEgNTMuMDE1MikiPgo8c3RvcCBzdG9wLWNvbG9yPSIjM0JENUZGIi8+CjxzdG9wIG9mZnNldD0iMSIgc3RvcC1jb2xvcj0iIzAwNzhENCIvPgo8L3JhZGlhbEdyYWRpZW50Pgo8cmFk
aWFsR3JhZGllbnQgaWQ9InBhaW50Ml9yYWRpYWxfODY2Nl82NjI3IiBjeD0iMCIgY3k9IjAiIHI9IjEiIGdyYWRpZW50VW5pdHM9InVzZXJTcGFjZU9uVXNlIiBncmFkaWVudFRyYW5zZm9ybT0idHJhbnNsYXRlKDM5LjU5ODggMjcuNTk5Nykgcm90YXRlKDE0Ni4zMSkgc2NhbGUoMjUuOTYgNDMuNzk2KSI+CjxzdG9wIHN0b3AtY29sb3I9IiNERUNCRkYiIHN0b3Atb3BhY2l0eT0iMC45Ii8+CjxzdG9wIG9mZnNldD0iMSIgc3RvcC1jb2xvcj0iI0RFQ0JGRiIgc3RvcC1vcGFjaXR5PSIwIi8+CjwvcmFkaWFsR3JhZGllbnQ+CjxsaW5lYXJHcmFkaWVudCBpZD0icGFpbnQzX2xpbmVhcl84NjY2XzY2MjciIHgxPSIzMi4zOTg5IiB5MT0iMS4yIiB4Mj0iOS42NjgzNSIgeTI9IjI1Ljk0ODkiIGdyYWRpZW50VW5pdHM9InVzZXJTcGFjZU9uVXNlIj4KPHN0b3Agc3RvcC1jb2xvcj0iIzBGQUZGRiIvPgo8c3RvcCBvZmZzZXQ9IjAuMTYyNzE0IiBzdG9wLWNvbG9yPSIjMDA5NEYwIi8+CjxzdG9wIG9mZnNldD0iMC41NjM4NzEiIHN0b3AtY29sb3I9IiMyMDUyQ0IiLz4KPHN0b3Agb2Zmc2V0PSIwLjc2NDI4MyIgc3RvcC1jb2xvcj0iIzMxMkE5QSIvPgo8L2xpbmVhckdyYWRpZW50Pgo8cmFkaWFsR3JhZGllbnQgaWQ9InBhaW50NF9yYWRpYWxfODY2Nl82NjI3IiBjeD0iMCIgY3k9IjAiIHI9IjEiIGdyYWRpZW50VW5pdHM9InVzZXJTcGFjZU9uVXNlIiBncmFkaWVudFRyYW5zZm9ybT0idHJhbnNsYXRlKDIzLjY2NzggMTIuMTQ4OCkgcm90YXRlKDEzNy41Mykgc2NhbGUoMjQuNjYwNiAzOC40MTE3KSI+CjxzdG9wIHN0b3AtY29sb3I9IiMzQkQ1RkYiLz4KPHN0b3Agb2Zmc2V0PSIxIiBzdG9wLWNvbG9yPSIjMDA3OEQ0Ii8+CjwvcmFkaWFsR3JhZGllbnQ+CjxyYWRpYWxHcmFkaWVudCBpZD0icGFpbnQ1X3JhZGlhbF84NjY2XzY2MjciIGN4PSIwIiBjeT0iMCIgcj0iMSIgZ3JhZGllbnRVbml0cz0idXNlclNwYWNlT25Vc2UiIGdyYWRpZW50VHJhbnNmb3JtPSJ0cmFuc2xhdGUoMjcuNTk4OSAxMC43OTk4KSByb3RhdGUoMTQ5LjAzNikgc2NhbGUoMjAuOTkxNCAzMy42MDA1KSI+CjxzdG9wIHN0b3AtY29sb3I9IiNERUNCRkYiIHN0b3Atb3BhY2l0eT0iMC45Ii8+CjxzdG9wIG9mZnNldD0iMSIgc3RvcC1jb2xvcj0iI0QxRDFGRiIgc3RvcC1vcGFjaXR5PSIwIi8+CjwvcmFkaWFsR3JhZGllbnQ+CjxjbGlwUGF0aCBpZD0iY2xpcDBfODY2Nl82NjI3Ij4KPHBhdGggZD0iTTAgMTIuOEMwIDguMzE5NTggMCA2LjA3OTM3IDAuODcxOTQ4IDQuMzY4MDhDMS42Mzg5MyAyLjg2Mjc4IDIuODYyNzggMS42Mzg5MyA0LjM2ODA4IDAuODcxOTQ4QzYuMDc5MzcgMCA4LjMxOTU4IDAgMTIuOCAwSDM1LjJDMzkuNjgwNCAwIDQxLjkyMDYgMCA0My42MzE5IDAuODcxOTQ4QzQ1LjEzNzIgMS42Mzg5MyA0Ni4zNjExIDIuODYyNzggNDcuMTI4MSA0LjM2ODA4QzQ4IDYuMDc5MzcgNDggOC4zMTk1OCA0OCAxMi44VjM1LjJDNDggMzkuNjgwNCA0OCA0MS45MjA2IDQ3LjEyODEgNDMuNjMxOUM0Ni4zNjExIDQ1LjEzNzIgNDUuMTM3MiA0Ni4zNjExIDQzLjYzMTkgNDcuMTI4MUM0MS45MjA2IDQ4IDM5LjY4MDQgNDggMzUuMiA0OEgxMi44QzguMzE5NTggNDggNi4wNzkzNyA0OCA0LjM2ODA4IDQ3LjEyODFDMi44NjI3OCA0Ni4zNjExIDEuNjM4OTMgNDUuMTM3MiAwLjg3MTk0OCA0My42MzE5QzAgNDEuOTIwNiAwIDM5LjY4MDQgMCAzNS4yVjEyLjhaIiBmaWxsPSJ3aGl0ZSIvPgo8L2NsaXBQYXRoPgo8Y2xpcFBhdGggaWQ9ImNsaXAxXzg2NjZfNjYyNyI+CjxyZWN0IHdpZHRoPSI0OCIgaGVpZ2h0PSI0OCIgZmlsbD0id2hpdGUiLz4KPC9jbGlwUGF0aD4KPC9kZWZzPgo8L3N2Zz4K" + + def disconnect(self): + with contextlib.suppress(Exception): + self.conn.close() + + def list_object_types(self): + return { + "database": ConnectionObjectInfo({"contains": None, "icon": None}), + "schema": ConnectionObjectInfo({"contains": None, "icon": None}), + "table": ConnectionObjectInfo({"contains": "data", "icon": None}), + "view": ConnectionObjectInfo({"contains": "data", "icon": None}), + } + + def list_objects(self, path: list[ObjectSchema]): + if len(path) == 0: + rows = self._execute("SELECT name FROM sys.databases ORDER BY name;") + return [ConnectionObject({"name": row[0], "kind": "database"}) for row in rows] + + if len(path) == 1: + database = path[0] + if database.kind != "database": + raise ValueError( + f"Invalid path. Expected it to include a database, but got '{database.kind}'. 
Path: {path}"
+                )
+
+            rows = self._execute(
+                f"SELECT name FROM {self._qualify(database.name, 'sys', 'schemas')} ORDER BY name;"
+            )
+            return [ConnectionObject({"name": row[0], "kind": "schema"}) for row in rows]
+
+        if len(path) == 2:
+            database, schema = path
+            if database.kind != "database" or schema.kind != "schema":
+                raise ValueError(
+                    "Path must include a database and schema in this order. "
+                    f"Got database.kind={database.kind}, schema.kind={schema.kind}. Path: {path}"
+                )
+
+            rows = self._execute(
+                f"""
+                SELECT TABLE_NAME, TABLE_TYPE
+                FROM {self._qualify(database.name, "INFORMATION_SCHEMA", "TABLES")}
+                WHERE TABLE_SCHEMA = {self._quote_literal(schema.name)}
+                ORDER BY TABLE_NAME;
+                """
+            )
+
+            objects: list[ConnectionObject] = []
+            for table_name, table_type in rows:
+                kind = "view" if "VIEW" in str(table_type).upper() else "table"
+                objects.append(ConnectionObject({"name": table_name, "kind": kind}))
+            return objects
+
+        raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}")
+
+    def list_fields(self, path: list[ObjectSchema]):
+        if len(path) != 3:
+            raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")
+
+        database, schema, table = path
+        if (
+            database.kind != "database"
+            or schema.kind != "schema"
+            or table.kind not in ["table", "view"]
+        ):
+            raise ValueError(
+                "Path must include a database, schema and table/view in this order. "
+                f"Got database.kind={database.kind}, schema.kind={schema.kind}, table.kind={table.kind}. "
+                f"Path: {path}"
+            )
+
+        rows = self._execute(
+            f"""
+            SELECT COLUMN_NAME, DATA_TYPE
+            FROM {self._qualify(database.name, "INFORMATION_SCHEMA", "COLUMNS")}
+            WHERE TABLE_SCHEMA = {self._quote_literal(schema.name)} AND TABLE_NAME = {self._quote_literal(table.name)}
+            ORDER BY ORDINAL_POSITION;
+            """
+        )
+
+        return [ConnectionObjectFields({"name": name, "dtype": dtype}) for name, dtype in rows]
+
+    def preview_object(self, path: list[ObjectSchema], var_name: str | None = None):
+        try:
+            import pandas as pd
+        except ImportError as e:
+            raise ModuleNotFoundError("Pandas is required for previewing SQL Server tables.") from e
+
+        if len(path) != 3:
+            raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")
+
+        database, schema, table = path
+        if (
+            database.kind != "database"
+            or schema.kind != "schema"
+            or table.kind not in ["table", "view"]
+        ):
+            raise ValueError(
+                "Path must include a database, schema and table/view in this order. "
+                f"Got database.kind={database.kind}, schema.kind={schema.kind}, table.kind={table.kind}. "
+                f"Path: {path}"
+            )
+
+        qualified_name = self._qualify(database.name, schema.name, table.name)
+        query = f"SELECT TOP 1000 * FROM {qualified_name};"
+        cursor = self.conn.cursor()
+        try:
+            cursor.execute(query)
+            rows = cursor.fetchall()
+            cols = [c[0] for c in cursor.description or []]
+        finally:
+            with contextlib.suppress(Exception):
+                cursor.close()
+
+        if self._is_pyodbc():
+            # pyodbc returns rows as pyodbc.Row, which pandas cannot handle directly
+            rows = [tuple(row) for row in rows]
+
+        preview_df = pd.DataFrame(rows, columns=cols)
+
+        var_name = var_name or "conn"
+        sql_string = (
+            f"# {table.name} = pd.read_sql({query!r}, {var_name}) "
+            f"# where {var_name} is your connection variable"
+        )
+        return preview_df, sql_string
+
+    def _execute(self, sql: str) -> list[tuple[Any, ...]]:
+        cursor = self.conn.cursor()
+        try:
+            cursor.execute(sql)
+            return cursor.fetchall()
+        finally:
+            with contextlib.suppress(Exception):
+                cursor.close()
+
+    def _fetch_one_value(self, sql: str) -> Any:
+        rows = self._execute(sql)
+        return rows[0][0] if rows else None
+
+    def _qualify(self, *parts: str) -> str:
+        return ".".join(self._quote_identifier(part) for part in parts)
+
+    def _quote_identifier(self, identifier: str) -> str:
+        return f"[{identifier.replace(']', ']]')}]"
+
+    def _quote_literal(self, value: str) -> str:
+        return "'" + value.replace("'", "''") + "'"
+
+    def _is_pyodbc(self) -> bool:
+        return safe_isinstance(self.conn, "pyodbc", "Connection")
+
+    def _pyodbc_getinfo(self, attr: str) -> str | None:
+        try:
+            import pyodbc  # type: ignore
+        except ImportError:
+            return None
+
+        constant = getattr(pyodbc, attr, None)
+        if constant is None:
+            return None
+
+        if not safe_isinstance(self.conn, "pyodbc", "Connection"):
+            return None
+
+        try:
+            value = self.conn.getinfo(constant)
+            return str(value) if value is not None else None
+        except Exception:
+            return None
+
+    def _pyodbc_connection_string(self) -> str | None:
+        driver = self._pyodbc_getinfo("SQL_DRIVER_NAME")
+        server = self._pyodbc_getinfo("SQL_SERVER_NAME")
+        if server is None and self.host != "":
+            server = self.host
+        database = self._pyodbc_getinfo("SQL_DATABASE_NAME") or self.database
+
+        if driver is None and server is None and database is None:
+            return None
+
+        parts = []
+        if driver:
+            parts.append(f"DRIVER={{{driver}}}")
+        if server:
+            parts.append(f"SERVER={server}")
+        if database:
+            parts.append(f"DATABASE={database}")
+        parts.append("Trusted_Connection=yes")
+        return ";".join(parts) + ";"
+
+    def _default_connection_string(self) -> str:
+        server = self.host if self.host != "" else ""
+        database = self.database or ""
+        return (
+            "DRIVER={ODBC Driver 18 for SQL Server};"
+            f"SERVER={server};"
+            f"DATABASE={database};"
+            "Trusted_Connection=yes;"
+        )
+
+    def _make_code(self):
+        if self._is_pyodbc():
+            conn_str = self._pyodbc_connection_string() or self._default_connection_string()
+            return f"import pyodbc\nconn = pyodbc.connect({conn_str!r})\n%connection_show conn\n"
+
+        server = self.host if self.host != "" else ""
+        database = self.database or ""
+        return (
+            "import pymssql\n"
+            "conn = pymssql.connect(\n"
+            f"    server={server!r},\n"
+            f"    database={database!r},\n"
+            "    user='',  # TODO: Replace with your username\n"
+            "    password='',  # TODO: Replace with your password\n"
+            ")\n"
+            "%connection_show conn\n"
+        )
+
+
 class DatabricksConnection(Connection):
     """Support for Databricks connections to databases."""
diff --git a/extensions/positron-python/python_files/posit/positron/tests/test_connections.py b/extensions/positron-python/python_files/posit/positron/tests/test_connections.py
index 7b8646b2040..9e425830282 100644
--- a/extensions/positron-python/python_files/posit/positron/tests/test_connections.py
+++ b/extensions/positron-python/python_files/posit/positron/tests/test_connections.py
@@ -60,6 +60,65 @@ TARGET_NAME = "positron.connections"
+SQLSERVER_ENDPOINT = os.environ.get("SQLSERVER_ENDPOINT")
+SQLSERVER_PASSWORD = os.environ.get("SQLSERVER_PASSWORD")
+SQLSERVER_USERNAME = os.environ.get("SQLSERVER_USERNAME")
+SQLSERVER_DATABASE = os.environ.get("SQLSERVER_DATABASE")
+SQLSERVER_ODBC_DRIVER = os.environ.get("SQLSERVER_ODBC_DRIVER", "ODBC Driver 18 for SQL Server")
+
+try:
+    import pyodbc
+
+    HAS_SQLSERVER_PYODBC = SQLSERVER_ENDPOINT is not None and SQLSERVER_PASSWORD is not None
+except ImportError:
+    pyodbc = None  # type: ignore[assignment]
+    HAS_SQLSERVER_PYODBC = False
+
+try:
+    import pymssql
+
+    HAS_SQLSERVER_PYMSSQL = SQLSERVER_ENDPOINT is not None and SQLSERVER_PASSWORD is not None
+except ImportError:
+    pymssql = None  # type: ignore[assignment]
+    HAS_SQLSERVER_PYMSSQL = False
+
+
+def get_sqlserver_pyodbc_connection():
+    if not HAS_SQLSERVER_PYODBC:
+        pytest.skip("SQL Server pyodbc connection not available")
+
+    assert pyodbc is not None
+    if SQLSERVER_ENDPOINT is None or SQLSERVER_PASSWORD is None:
+        pytest.skip("SQL Server endpoint/password not configured")
+    conn_str = (
+        f"DRIVER={{{SQLSERVER_ODBC_DRIVER}}};"
+        f"SERVER={SQLSERVER_ENDPOINT};"
+        f"DATABASE={SQLSERVER_DATABASE};"
+        f"UID={SQLSERVER_USERNAME};"
+        f"PWD={SQLSERVER_PASSWORD};"
+        "Encrypt=yes;"
+        "TrustServerCertificate=yes;"
+    )
+    return pyodbc.connect(conn_str)
+
+
+def get_sqlserver_pymssql_connection():
+    if not HAS_SQLSERVER_PYMSSQL:
+        pytest.skip("SQL Server pymssql connection not available")
+
+    assert pymssql is not None
+    if SQLSERVER_ENDPOINT is None or SQLSERVER_PASSWORD is None:
+        pytest.skip("SQL Server endpoint/password not configured")
+    kwargs = {
+        "server": SQLSERVER_ENDPOINT,
+        "user": SQLSERVER_USERNAME,
+        "password": SQLSERVER_PASSWORD,
+        "database": SQLSERVER_DATABASE,
+    }
+
+    return pymssql.connect(**kwargs)
+
+
 def add_default_data(execute):
     execute("CREATE TABLE movie(title TEXT, year INTEGER, score NUMERIC)")
     execute("INSERT INTO movie VALUES('The Shawshank Redemption', 1994, 9.3)")
@@ -436,6 +495,7 @@ def test_list_fields(self, connections_service: ConnectionsService):
         msg = _make_msg(params={"path": path}, method="list_fields", comm_id=comm_id)
         dummy_comm.handle_msg(msg)
         result = dummy_comm.messages[0]["data"]["result"]
+        assert len(result) > 0
         field_names = [field["name"] for field in result]
         for expected_field in ["name", "gender", "state", "year", "number"]:
             assert expected_field in field_names
@@ -574,6 +634,135 @@ def test_preview_object(self, connections_service: ConnectionsService):
         assert result is None
+
+class _SQLServerConnectionsTestBase:
+    schema_name = "dbo"
+    table_name = "flights"
+
+    def _get_connection(self):
+        raise NotImplementedError
+
+    def _get_database_name(self, conn):
+        cursor = conn.cursor()
+        cursor.execute("SELECT DB_NAME()")
+        database_name = cursor.fetchone()[0]
+        cursor.close()
+        return database_name
+
+    def _resolve_path(self, database_name: str, kind: str):
+        if kind == "root":
+            return []
+        if kind == "database":
+            return [{"kind": "database", "name": database_name}]
+        if kind == "schema":
+            return [
+                {"kind": "database", "name": database_name},
+                {"kind": "schema", "name": self.schema_name},
+            ]
+        if kind == "table":
+            return [
+                {"kind": "database", "name": database_name},
+                {"kind": "schema", "name": self.schema_name},
+                {"kind": "table", "name": self.table_name},
+            ]
+        raise ValueError(f"Unknown path kind: {kind}")
+
+    def _open_comm(self, connections_service: ConnectionsService):
+        con = self._get_connection()
+        database_name = self._get_database_name(con)
+        comm_id = connections_service.register_connection(con)
+        dummy_comm = DummyComm(TARGET_NAME, comm_id=comm_id)
+        connections_service.on_comm_open(dummy_comm)
+        dummy_comm.messages.clear()
+        return dummy_comm, comm_id, database_name
+
+    def test_register_connection(self, connections_service: ConnectionsService):
+        con = self._get_connection()
+        comm_id = connections_service.register_connection(con)
+        assert comm_id in connections_service.comms
+
+    @pytest.mark.parametrize(
+        "path_kind",
+        ["root", "database", "schema", "table"],
+    )
+    def test_contains_data(self, connections_service: ConnectionsService, path_kind: str):
+        dummy_comm, comm_id, database_name = self._open_comm(connections_service)
+        path = self._resolve_path(database_name, path_kind)
+
+        msg = _make_msg(params={"path": path}, method="contains_data", comm_id=comm_id)
+        dummy_comm.handle_msg(msg)
+        result = dummy_comm.messages[0]["data"]["result"]
+        assert result == (path_kind == "table")
+
+    @pytest.mark.parametrize(
+        ("path_kind", "expected"),
+        [
+            ("database", ""),
+            ("schema", ""),
+            ("table", ""),
+        ],
+    )
+    def test_get_icon(self, connections_service: ConnectionsService, path_kind: str, expected: str):
+        dummy_comm, comm_id, database_name = self._open_comm(connections_service)
+        path = self._resolve_path(database_name, path_kind)
+
+        msg = _make_msg(params={"path": path}, method="get_icon", comm_id=comm_id)
+        dummy_comm.handle_msg(msg)
+        result = dummy_comm.messages[0]["data"]["result"]
+        assert result == expected
+
+    @pytest.mark.parametrize("path_kind", ["root", "database", "schema"])
+    def test_list_objects(self, connections_service: ConnectionsService, path_kind: str):
+        dummy_comm, comm_id, database_name = self._open_comm(connections_service)
+        path = self._resolve_path(database_name, path_kind)
+
+        msg = _make_msg(params={"path": path}, method="list_objects", comm_id=comm_id)
+        dummy_comm.handle_msg(msg)
+        result = dummy_comm.messages[0]["data"]["result"]
+        names = [item["name"] for item in result]
+        if path_kind == "root":
+            assert database_name in names
+        elif path_kind == "database":
+            assert self.schema_name in names
+        else:
+            assert self.table_name in names
+
+    def test_list_fields(self, connections_service: ConnectionsService):
+        dummy_comm, comm_id, database_name = self._open_comm(connections_service)
+        path = self._resolve_path(database_name, "table")
+
+        msg = _make_msg(params={"path": path}, method="list_fields", comm_id=comm_id)
+        dummy_comm.handle_msg(msg)
+        result = dummy_comm.messages[0]["data"]["result"]
+        assert len(result) > 0
+        field_names = [field["name"] for field in result]
+        assert "air_time" in field_names
+        assert "carrier" in field_names
+        assert "distance" in field_names
+
+    def test_preview_object(self, connections_service: ConnectionsService):
+        pytest.importorskip("pandas", reason="pandas required for SQL Server preview")
+        dummy_comm, comm_id, database_name = self._open_comm(connections_service)
+        path = self._resolve_path(database_name, "table")
+
+        msg = _make_msg(params={"path": path}, method="preview_object", comm_id=comm_id)
+        dummy_comm.handle_msg(msg)
+        connections_service._kernel.data_explorer_service.shutdown()  # noqa: SLF001
+        result = dummy_comm.messages[0]["data"]["result"]
+        assert result is None
+
+
+@pytest.mark.skipif(not HAS_SQLSERVER_PYODBC, reason="SQL Server pyodbc connection not available")
+class TestSQLServerPyodbcConnectionsService(_SQLServerConnectionsTestBase):
+    def _get_connection(self):
+        return get_sqlserver_pyodbc_connection()
+
+
+@pytest.mark.skipif(not HAS_SQLSERVER_PYMSSQL, reason="SQL Server pymssql connection not available")
+class TestSQLServerPymssqlConnectionsService(_SQLServerConnectionsTestBase):
+    def _get_connection(self):
+        return get_sqlserver_pymssql_connection()
+
+
 @pytest.mark.skipif(not HAS_SNOWFLAKE, reason="Snowflake not available")
 class TestSnowflakeConnectionsService:
     DATABASE_NAME = "POSITRON_CONNECTIONS_PANE_TESTS"