Skip to content

Commit bc1efef

Browse files
committed
add support for SQL server trought pyodbc and pymssql
1 parent 3b69d8f commit bc1efef

File tree

1 file changed

+250
-0
lines changed

1 file changed

+250
-0
lines changed

extensions/positron-python/python_files/posit/positron/connections.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,25 @@
4444
logger = logging.getLogger(__name__)
4545

4646

47+
def _is_pyodbc_sqlserver(conn: Any) -> bool:
48+
"""Return True if `conn` is a pyodbc connection to SQL Server."""
49+
if not safe_isinstance(conn, "pyodbc", "Connection"):
50+
return False
51+
52+
try:
53+
import pyodbc
54+
except ImportError:
55+
return False
56+
57+
try:
58+
dbms_name = str(conn.getinfo(pyodbc.SQL_DBMS_NAME))
59+
except Exception:
60+
return False
61+
62+
upper_name = dbms_name.upper()
63+
return "SQL SERVER" in upper_name or "AZURE SQL" in upper_name
64+
65+
4766
class ConnectionWarning(UserWarning):
4867
"""
4968
Warning raised when there are issues in the Connections Pane relevant to the user.
@@ -322,6 +341,8 @@ def _wrap_connection(self, obj: Any) -> Connection:
322341
return GoogleBigQueryConnection(obj)
323342
elif safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection"):
324343
return SnowflakeConnection(obj)
344+
elif _is_pyodbc_sqlserver(obj) or safe_isinstance(obj, "pymssql", "Connection"):
345+
return SQLServerConnection(obj)
325346
elif safe_isinstance(obj, "databricks.sql.client", "Connection"):
326347
return DatabricksConnection(obj)
327348
else:
@@ -343,6 +364,8 @@ def object_is_supported(self, obj: Any) -> bool:
343364
)
344365
or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection")
345366
or safe_isinstance(obj, "databricks.sql.client", "Connection")
367+
or _is_pyodbc_sqlserver(obj)
368+
or safe_isinstance(obj, "pymssql", "Connection")
346369
)
347370
except Exception as err:
348371
logger.error(f"Error checking supported {err}")
@@ -1159,6 +1182,233 @@ def _make_code(self):
11591182
return code
11601183

11611184

1185+
class SQLServerConnection(Connection):
1186+
"""Support for SQL Server connections to databases."""
1187+
1188+
def __init__(self, conn: Any):
1189+
self.conn = conn
1190+
1191+
try:
1192+
self.host = self._fetch_one_value("SELECT @@SERVERNAME")
1193+
except Exception:
1194+
self.host = "<unknown>"
1195+
1196+
self.database: str | None = None
1197+
with contextlib.suppress(Exception):
1198+
db_name = self._fetch_one_value("SELECT DB_NAME()")
1199+
self.database = str(db_name) if db_name not in (None, "") else None
1200+
1201+
self.type = "SQLServer" + (" (pyodbc)" if self._is_pyodbc() else " (pymssql)")
1202+
self.display_name = f"{self.type} - {self.host}"
1203+
1204+
self.code = self._make_code()
1205+
self.icon = ""
1206+
1207+
def disconnect(self):
1208+
with contextlib.suppress(Exception):
1209+
self.conn.close()
1210+
1211+
def list_object_types(self):
1212+
return {
1213+
"database": ConnectionObjectInfo({"contains": None, "icon": None}),
1214+
"schema": ConnectionObjectInfo({"contains": None, "icon": None}),
1215+
"table": ConnectionObjectInfo({"contains": "data", "icon": None}),
1216+
"view": ConnectionObjectInfo({"contains": "data", "icon": None}),
1217+
}
1218+
1219+
def list_objects(self, path: list[ObjectSchema]):
1220+
if len(path) == 0:
1221+
rows = self._execute("SELECT name FROM sys.databases ORDER BY name;")
1222+
return [ConnectionObject({"name": row[0], "kind": "database"}) for row in rows]
1223+
1224+
if len(path) == 1:
1225+
database = path[0]
1226+
if database.kind != "database":
1227+
raise ValueError(
1228+
f"Invalid path. Expected it to include a database, but got '{database.kind}'. Path: {path}"
1229+
)
1230+
1231+
rows = self._execute(
1232+
f"SELECT name FROM {self._qualify(database.name, 'sys', 'schemas')} ORDER BY name;"
1233+
)
1234+
return [ConnectionObject({"name": row[0], "kind": "schema"}) for row in rows]
1235+
1236+
if len(path) == 2:
1237+
database, schema = path
1238+
if database.kind != "database" or schema.kind != "schema":
1239+
raise ValueError(
1240+
"Path must include a database and schema in this order. "
1241+
f"Got database.kind={database.kind}, schema.kind={schema.kind}. Path: {path}"
1242+
)
1243+
1244+
rows = self._execute(
1245+
f"""
1246+
SELECT TABLE_NAME, TABLE_TYPE
1247+
FROM {self._qualify(database.name, 'INFORMATION_SCHEMA', 'TABLES')}
1248+
WHERE TABLE_SCHEMA = {self._quote_literal(schema.name)}
1249+
ORDER BY TABLE_NAME;
1250+
"""
1251+
)
1252+
1253+
objects: list[ConnectionObject] = []
1254+
for table_name, table_type in rows:
1255+
kind = "view" if "VIEW" in str(table_type).upper() else "table"
1256+
objects.append(ConnectionObject({"name": table_name, "kind": kind}))
1257+
return objects
1258+
1259+
raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}")
1260+
1261+
def list_fields(self, path: list[ObjectSchema]):
1262+
if len(path) != 3:
1263+
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")
1264+
1265+
database, schema, table = path
1266+
if (
1267+
database.kind != "database"
1268+
or schema.kind != "schema"
1269+
or table.kind not in ["table", "view"]
1270+
):
1271+
raise ValueError(
1272+
"Path must include a database, schema and table/view in this order. "
1273+
f"Got database.kind={database.kind}, schema.kind={schema.kind}, table.kind={table.kind}. "
1274+
f"Path: {path}"
1275+
)
1276+
1277+
rows = self._execute(
1278+
f"""
1279+
SELECT COLUMN_NAME, DATA_TYPE
1280+
FROM {self._qualify(database.name, 'INFORMATION_SCHEMA', 'COLUMNS')}
1281+
WHERE TABLE_SCHEMA = {self._quote_literal(schema.name)} AND TABLE_NAME = {self._quote_literal(table.name)}
1282+
ORDER BY ORDINAL_POSITION;
1283+
"""
1284+
)
1285+
1286+
return [ConnectionObjectFields({"name": name, "dtype": dtype}) for name, dtype in rows]
1287+
1288+
def preview_object(self, path: list[ObjectSchema], var_name: str | None = None):
1289+
try:
1290+
import pandas as pd
1291+
except ImportError as e:
1292+
raise ModuleNotFoundError("Pandas is required for previewing SQL Server tables.") from e
1293+
1294+
if len(path) != 3:
1295+
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")
1296+
1297+
database, schema, table = path
1298+
if (
1299+
database.kind != "database"
1300+
or schema.kind != "schema"
1301+
or table.kind not in ["table", "view"]
1302+
):
1303+
raise ValueError(
1304+
"Path must include a database, schema and table/view in this order. "
1305+
f"Got database.kind={database.kind}, schema.kind={schema.kind}, table.kind={table.kind}. "
1306+
f"Path: {path}"
1307+
)
1308+
1309+
qualified_name = self._qualify(database.name, schema.name, table.name)
1310+
query = f"SELECT TOP 1000 * FROM {qualified_name};"
1311+
var_name = var_name or "conn"
1312+
preview_df = pd.read_sql(query, self.conn)
1313+
sql_string = (
1314+
f"# {table.name} = pd.read_sql({query!r}, {var_name}) "
1315+
f"# where {var_name} is your connection variable"
1316+
)
1317+
return preview_df, sql_string
1318+
1319+
def _execute(self, sql: str) -> list[tuple[Any, ...]]:
1320+
cursor = self.conn.cursor()
1321+
try:
1322+
cursor.execute(sql)
1323+
return cursor.fetchall()
1324+
finally:
1325+
with contextlib.suppress(Exception):
1326+
cursor.close()
1327+
1328+
def _fetch_one_value(self, sql: str) -> Any:
1329+
rows = self._execute(sql)
1330+
return rows[0][0] if rows else None
1331+
1332+
def _qualify(self, *parts: str) -> str:
1333+
return ".".join(self._quote_identifier(part) for part in parts)
1334+
1335+
def _quote_identifier(self, identifier: str) -> str:
1336+
return f"[{identifier.replace(']', ']]')}]"
1337+
1338+
def _quote_literal(self, value: str) -> str:
1339+
return "'" + value.replace("'", "''") + "'"
1340+
1341+
def _is_pyodbc(self) -> bool:
1342+
return safe_isinstance(self.conn, "pyodbc", "Connection")
1343+
1344+
def _pyodbc_getinfo(self, attr: str) -> str | None:
1345+
try:
1346+
import pyodbc # type: ignore
1347+
except ImportError:
1348+
return None
1349+
1350+
constant = getattr(pyodbc, attr, None)
1351+
if constant is None:
1352+
return None
1353+
1354+
if not safe_isinstance(self.conn, "pyodbc", "Connection"):
1355+
return None
1356+
1357+
try:
1358+
value = self.conn.getinfo(constant)
1359+
return str(value) if value is not None else None
1360+
except Exception:
1361+
return None
1362+
1363+
def _pyodbc_connection_string(self) -> str | None:
1364+
driver = self._pyodbc_getinfo("SQL_DRIVER_NAME")
1365+
server = self._pyodbc_getinfo("SQL_SERVER_NAME")
1366+
if server is None and self.host != "<unknown>":
1367+
server = self.host
1368+
database = self._pyodbc_getinfo("SQL_DATABASE_NAME") or self.database
1369+
1370+
if driver is None and server is None and database is None:
1371+
return None
1372+
1373+
parts = []
1374+
if driver:
1375+
parts.append(f"DRIVER={{{driver}}}")
1376+
if server:
1377+
parts.append(f"SERVER={server}")
1378+
if database:
1379+
parts.append(f"DATABASE={database}")
1380+
parts.append("Trusted_Connection=yes")
1381+
return ";".join(parts) + ";"
1382+
1383+
def _default_connection_string(self) -> str:
1384+
server = self.host if self.host != "<unknown>" else "<server>"
1385+
database = self.database or "<database>"
1386+
return (
1387+
"DRIVER={ODBC Driver 18 for SQL Server};"
1388+
f"SERVER={server};"
1389+
f"DATABASE={database};"
1390+
"Trusted_Connection=yes;"
1391+
)
1392+
1393+
def _make_code(self):
1394+
if self._is_pyodbc():
1395+
conn_str = self._pyodbc_connection_string() or self._default_connection_string()
1396+
return (
1397+
"import pyodbc\n" f"conn = pyodbc.connect({conn_str!r})\n" "%connection_show conn\n"
1398+
)
1399+
1400+
server = self.host if self.host != "<unknown>" else "<server>"
1401+
database = self.database or "<database>"
1402+
return (
1403+
"import pymssql\n"
1404+
"conn = pymssql.connect(\n"
1405+
f" server={server!r},\n"
1406+
f" database={database!r},\n"
1407+
")\n"
1408+
"%connection_show conn\n"
1409+
)
1410+
1411+
11621412
class DatabricksConnection(Connection):
11631413
"""Support for Databricks connections to databases."""
11641414

0 commit comments

Comments
 (0)