Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ polars==1.35.2
polars[timezone]==1.35.2; sys_platform == 'win32'
pyarrow==21.0.0; python_version == '3.9'
pyarrow==22.0.0; python_version > '3.10'
pymssql==2.3.10; python_version < '3.14'
pyodbc==5.3.0; python_version < '3.14'
pytest==8.4.2; python_version == '3.9'
pytest==9.0.1; python_version >= '3.10'
pytest-asyncio==1.2.0; python_version == '3.9'
Expand Down
264 changes: 264 additions & 0 deletions extensions/positron-python/python_files/posit/positron/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@
logger = logging.getLogger(__name__)


def _is_pyodbc_sqlserver(conn: Any) -> bool:
"""Return True if `conn` is a pyodbc connection to SQL Server."""
if not safe_isinstance(conn, "pyodbc", "Connection"):
return False

try:
import pyodbc
except ImportError:
return False

try:
dbms_name = str(conn.getinfo(pyodbc.SQL_DBMS_NAME))
except Exception:
return False

upper_name = dbms_name.upper()
return "SQL SERVER" in upper_name or "AZURE SQL" in upper_name


class ConnectionWarning(UserWarning):
"""
Warning raised when there are issues in the Connections Pane relevant to the user.
Expand Down Expand Up @@ -322,6 +341,8 @@ def _wrap_connection(self, obj: Any) -> Connection:
return GoogleBigQueryConnection(obj)
elif safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection"):
return SnowflakeConnection(obj)
elif _is_pyodbc_sqlserver(obj) or safe_isinstance(obj, "pymssql", "Connection"):
return SQLServerConnection(obj)
elif safe_isinstance(obj, "databricks.sql.client", "Connection"):
return DatabricksConnection(obj)
else:
Expand All @@ -343,6 +364,8 @@ def object_is_supported(self, obj: Any) -> bool:
)
or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection")
or safe_isinstance(obj, "databricks.sql.client", "Connection")
or _is_pyodbc_sqlserver(obj)
or safe_isinstance(obj, "pymssql", "Connection")
)
except Exception as err:
logger.error(f"Error checking supported {err}")
Expand Down Expand Up @@ -1159,6 +1182,247 @@ def _make_code(self):
return code


class SQLServerConnection(Connection):
"""Support for SQL Server connections to databases."""

def __init__(self, conn: Any):
self.conn = conn

try:
self.host = self._fetch_one_value("SELECT @@SERVERNAME")
except Exception:
self.host = "<unknown>"

self.database: str | None = None
with contextlib.suppress(Exception):
db_name = self._fetch_one_value("SELECT DB_NAME()")
self.database = str(db_name) if db_name not in (None, "") else None

self.type = "SQLServer" + (" (pyodbc)" if self._is_pyodbc() else " (pymssql)")
self.display_name = f"{self.type} - {self.host}"

self.code = self._make_code()
self.icon = ""

def disconnect(self):
with contextlib.suppress(Exception):
self.conn.close()

def list_object_types(self):
return {
"database": ConnectionObjectInfo({"contains": None, "icon": None}),
"schema": ConnectionObjectInfo({"contains": None, "icon": None}),
"table": ConnectionObjectInfo({"contains": "data", "icon": None}),
"view": ConnectionObjectInfo({"contains": "data", "icon": None}),
}

def list_objects(self, path: list[ObjectSchema]):
if len(path) == 0:
rows = self._execute("SELECT name FROM sys.databases ORDER BY name;")
return [ConnectionObject({"name": row[0], "kind": "database"}) for row in rows]

if len(path) == 1:
database = path[0]
if database.kind != "database":
raise ValueError(
f"Invalid path. Expected it to include a database, but got '{database.kind}'. Path: {path}"
)

rows = self._execute(
f"SELECT name FROM {self._qualify(database.name, 'sys', 'schemas')} ORDER BY name;"
)
return [ConnectionObject({"name": row[0], "kind": "schema"}) for row in rows]

if len(path) == 2:
database, schema = path
if database.kind != "database" or schema.kind != "schema":
raise ValueError(
"Path must include a database and schema in this order. "
f"Got database.kind={database.kind}, schema.kind={schema.kind}. Path: {path}"
)

rows = self._execute(
f"""
SELECT TABLE_NAME, TABLE_TYPE
FROM {self._qualify(database.name, "INFORMATION_SCHEMA", "TABLES")}
WHERE TABLE_SCHEMA = {self._quote_literal(schema.name)}
ORDER BY TABLE_NAME;
"""
)

objects: list[ConnectionObject] = []
for table_name, table_type in rows:
kind = "view" if "VIEW" in str(table_type).upper() else "table"
objects.append(ConnectionObject({"name": table_name, "kind": kind}))
return objects

raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}")

def list_fields(self, path: list[ObjectSchema]):
if len(path) != 3:
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")

database, schema, table = path
if (
database.kind != "database"
or schema.kind != "schema"
or table.kind not in ["table", "view"]
):
raise ValueError(
"Path must include a database, schema and table/view in this order. "
f"Got database.kind={database.kind}, schema.kind={schema.kind}, table.kind={table.kind}. "
f"Path: {path}"
)

rows = self._execute(
f"""
SELECT COLUMN_NAME, DATA_TYPE
FROM {self._qualify(database.name, "INFORMATION_SCHEMA", "COLUMNS")}
WHERE TABLE_SCHEMA = {self._quote_literal(schema.name)} AND TABLE_NAME = {self._quote_literal(table.name)}
ORDER BY ORDINAL_POSITION;
"""
)

return [ConnectionObjectFields({"name": name, "dtype": dtype}) for name, dtype in rows]

def preview_object(self, path: list[ObjectSchema], var_name: str | None = None):
try:
import pandas as pd
except ImportError as e:
raise ModuleNotFoundError("Pandas is required for previewing SQL Server tables.") from e

if len(path) != 3:
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")

database, schema, table = path
if (
database.kind != "database"
or schema.kind != "schema"
or table.kind not in ["table", "view"]
):
raise ValueError(
"Path must include a database, schema and table/view in this order. "
f"Got database.kind={database.kind}, schema.kind={schema.kind}, table.kind={table.kind}. "
f"Path: {path}"
)

qualified_name = self._qualify(database.name, schema.name, table.name)
query = f"SELECT TOP 1000 * FROM {qualified_name};"
cursor = self.conn.cursor()
try:
cursor.execute(query)
rows = cursor.fetchall()
cols = [c[0] for c in cursor.description or []]
finally:
with contextlib.suppress(Exception):
cursor.close()

if self._is_pyodbc():
# pyodbc returns rows as pyodbc.Row, which pandas cannot handle directly
rows = [tuple(row) for row in rows]

preview_df = pd.DataFrame(rows, columns=cols)

var_name = var_name or "conn"
sql_string = (
f"# {table.name} = pd.read_sql({query!r}, {var_name}) "
f"# where {var_name} is your connection variable"
)
return preview_df, sql_string

def _execute(self, sql: str) -> list[tuple[Any, ...]]:
cursor = self.conn.cursor()
try:
cursor.execute(sql)
return cursor.fetchall()
finally:
with contextlib.suppress(Exception):
cursor.close()

def _fetch_one_value(self, sql: str) -> Any:
rows = self._execute(sql)
return rows[0][0] if rows else None

def _qualify(self, *parts: str) -> str:
return ".".join(self._quote_identifier(part) for part in parts)

def _quote_identifier(self, identifier: str) -> str:
return f"[{identifier.replace(']', ']]')}]"

def _quote_literal(self, value: str) -> str:
return "'" + value.replace("'", "''") + "'"

def _is_pyodbc(self) -> bool:
return safe_isinstance(self.conn, "pyodbc", "Connection")

def _pyodbc_getinfo(self, attr: str) -> str | None:
try:
import pyodbc # type: ignore
except ImportError:
return None

constant = getattr(pyodbc, attr, None)
if constant is None:
return None

if not safe_isinstance(self.conn, "pyodbc", "Connection"):
return None

try:
value = self.conn.getinfo(constant)
return str(value) if value is not None else None
except Exception:
return None

def _pyodbc_connection_string(self) -> str | None:
driver = self._pyodbc_getinfo("SQL_DRIVER_NAME")
server = self._pyodbc_getinfo("SQL_SERVER_NAME")
if server is None and self.host != "<unknown>":
server = self.host
database = self._pyodbc_getinfo("SQL_DATABASE_NAME") or self.database

if driver is None and server is None and database is None:
return None

parts = []
if driver:
parts.append(f"DRIVER={{{driver}}}")
if server:
parts.append(f"SERVER={server}")
if database:
parts.append(f"DATABASE={database}")
parts.append("Trusted_Connection=yes")
return ";".join(parts) + ";"

def _default_connection_string(self) -> str:
server = self.host if self.host != "<unknown>" else "<server>"
database = self.database or "<database>"
return (
"DRIVER={ODBC Driver 18 for SQL Server};"
f"SERVER={server};"
f"DATABASE={database};"
"Trusted_Connection=yes;"
)

def _make_code(self):
if self._is_pyodbc():
conn_str = self._pyodbc_connection_string() or self._default_connection_string()
return f"import pyodbc\nconn = pyodbc.connect({conn_str!r})\n%connection_show conn\n"

server = self.host if self.host != "<unknown>" else "<server>"
database = self.database or "<database>"
return (
"import pymssql\n"
"conn = pymssql.connect(\n"
f" server={server!r},\n"
f" database={database!r},\n"
" user='<username>', # TODO: Replace with your username\n"
" password='<password>', # TODO: Replace with your password\n"
")\n"
"%connection_show conn\n"
)


class DatabricksConnection(Connection):
"""Support for Databricks connections to databases."""

Expand Down
Loading