Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pytest==9.0.1; python_version >= '3.10'
pytest-asyncio==1.2.0; python_version == '3.9'
pytest-asyncio==1.3.0; python_version >= '3.10'
pytest-mock==3.15.1
redshift_connector==2.1.10; python_version < '3.14'
syrupy==4.9.1; python_version == '3.9'
syrupy==5.0.0; python_version >= '3.10'
torch==2.8.0; python_version == '3.9'
Expand Down
179 changes: 178 additions & 1 deletion extensions/positron-python/python_files/posit/positron/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,6 @@ def _wrap_connection(self, obj: Any) -> Connection:
if not self.object_is_supported(obj):
type_name = type(obj).__name__
raise UnsupportedConnectionError(f"Unsupported connection type {type_name}")

if safe_isinstance(obj, "sqlite3", "Connection"):
return SQLite3Connection(obj)
elif safe_isinstance(obj, "sqlalchemy", "Engine"):
Expand All @@ -324,6 +323,8 @@ def _wrap_connection(self, obj: Any) -> Connection:
return SnowflakeConnection(obj)
elif safe_isinstance(obj, "databricks.sql.client", "Connection"):
return DatabricksConnection(obj)
elif safe_isinstance(obj, "redshift_connector", "Connection"):
return RedshiftConnection(obj)
else:
type_name = type(obj).__name__
raise UnsupportedConnectionError(f"Unsupported connection type {type(obj)}")
Expand All @@ -343,6 +344,7 @@ def object_is_supported(self, obj: Any) -> bool:
)
or safe_isinstance(obj, "snowflake.connector", "SnowflakeConnection")
or safe_isinstance(obj, "databricks.sql.client", "Connection")
or safe_isinstance(obj, "redshift_connector", "Connection")
)
except Exception as err:
logger.error(f"Error checking supported {err}")
Expand Down Expand Up @@ -1342,3 +1344,178 @@ def _make_code(self) -> str:
")\n"
"%connection_show con\n"
)


class RedshiftConnection(Connection):
"""Support for Redshift connections to databases."""

def __init__(self, conn: Any):
self.conn = conn

try:
# Unfortunately there's no public API to get the host, so we access the protected member.
# to at least provide some info in the connection display name.
host, _ = conn._usock.getpeername() # noqa: SLF001
except AttributeError:
host = "<unknown>"

self.host = str(host)

self.display_name = f"Redshift ({self.host})"
self.type = "Redshift"
self.code = self._make_code()

self.icon = ""

def disconnect(self):
with contextlib.suppress(Exception):
self.conn.close()

def list_object_types(self):
return {
"database": ConnectionObjectInfo({"contains": None, "icon": None}),
"schema": ConnectionObjectInfo({"contains": None, "icon": None}),
"table": ConnectionObjectInfo({"contains": "data", "icon": None}),
"view": ConnectionObjectInfo({"contains": "data", "icon": None}),
}

def list_objects(self, path: list[ObjectSchema]):
if len(path) == 0:
rows = self._query("SHOW DATABASES;")
return [
ConnectionObject({"name": row["database_name"], "kind": "database"}) for row in rows
]

if len(path) == 1:
database = path[0]
if database.kind != "database":
raise ValueError("Expected database on path position 0.", f"Path: {path}")
database_ident = self._qualify(database.name)
rows = self._query(f"SHOW SCHEMAS FROM DATABASE {database_ident};")
return [
ConnectionObject(
{
"name": row["schema_name"],
"kind": "schema",
}
)
for row in rows
]

if len(path) == 2:
database, schema = path
if database.kind != "database" or schema.kind != "schema":
raise ValueError(
"Expected database and schema objects at positions 0 and 1.", f"Path: {path}"
)
location = f"{self._qualify(database.name)}.{self._qualify(schema.name)}"
tables = self._query(f"SHOW TABLES FROM SCHEMA {location};")
return [
ConnectionObject(
{
"name": row["table_name"],
"kind": row["table_type"].lower(),
}
)
for row in tables
]

raise ValueError(f"Path length must be at most 2, but got {len(path)}. Path: {path}")

def list_fields(self, path: list[ObjectSchema]):
if len(path) != 3:
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")

database, schema, table = path
if (
database.kind != "database"
or schema.kind != "schema"
or table.kind not in ("table", "view")
):
raise ValueError(
"Expected database, schema, and table/view kinds in the path.",
f"Path: {path}",
)

identifier = ".".join(
[self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)]
)
rows = self._query(f"SHOW COLUMNS FROM TABLE {identifier};")
return [
ConnectionObjectFields(
{
"name": row["column_name"],
"dtype": row["data_type"],
}
)
for row in rows
]

def preview_object(self, path: list[ObjectSchema], var_name: str | None = None):
if len(path) != 3:
raise ValueError(f"Path length must be 3, but got {len(path)}. Path: {path}")

database, schema, table = path
if (
database.kind != "database"
or schema.kind != "schema"
or table.kind not in ("table", "view")
):
raise ValueError(
"Expected database, schema, and table/view kinds in the path.",
f"Path: {path}",
)

identifier = ".".join(
[self._qualify(database.name), self._qualify(schema.name), self._qualify(table.name)]
)
sql = f"SELECT * FROM {identifier} LIMIT 1000;"

with self.conn.cursor() as cursor:
try:
cursor.execute(sql)
frame = cursor.fetch_dataframe()
except Exception:
# Rollback on error to avoid transaction issues
# for subsequent queries
self.conn.rollback()
raise

var_name = var_name or "conn"
return frame, (
f"with {var_name}.cursor() as cursor:\n"
f" cursor.execute({sql!r})\n"
f" {table.name} = cursor.fetch_dataframe()"
)

def _query(self, sql: str) -> list[dict[str, Any]]:
cursor = self.conn.cursor()
try:
cursor.execute(sql)
rows = cursor.fetchall()
description = cursor.description or []
columns = [col[0] for col in description]
return [dict(zip(columns, row)) for row in rows]
except Exception:
# Rollback on error to avoid transaction issues
# for subsequent queries
self.conn.rollback()
raise
finally:
cursor.close()

def _qualify(self, identifier: str) -> str:
escaped = identifier.replace('"', '""')
return f'"{escaped}"'

def _make_code(self) -> str:
return (
"# Requires redshift-connector package\n"
"# Authentication steps may be incomplete, adjust as needed.\n"
"import redshift_connector\n"
"con = redshift_connector.connect(\n"
" iam = True,\n"
f" host = '{self.host}',\n"
")\n"
"%connection_show con\n"
)
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,18 @@ def _is_active(self, value) -> bool:
return True


class RedshiftConnectionInspector(BaseConnectionInspector):
CLASS_QNAME = ("redshift_connector.Connection", "redshift_connector.core.Connection")

def _is_active(self, value) -> bool:
try:
# a connection is active if you can acquire a cursor from it
value.cursor()
except Exception:
return False
return True


class IbisExprInspector(PositronInspector["ibis.Expr"]):
def has_children(self) -> bool:
return False
Expand Down Expand Up @@ -1290,6 +1302,7 @@ def to_plaintext(self) -> str:
**dict.fromkeys(SnowflakeConnectionInspector.CLASS_QNAME, SnowflakeConnectionInspector),
**dict.fromkeys(DatabricksConnectionInspector.CLASS_QNAME, DatabricksConnectionInspector),
**dict.fromkeys(BigQueryConnectionInspector.CLASS_QNAME, BigQueryConnectionInspector),
**dict.fromkeys(RedshiftConnectionInspector.CLASS_QNAME, RedshiftConnectionInspector),
"ibis.Expr": IbisExprInspector,
"boolean": BooleanInspector,
"bytes": BytesInspector,
Expand Down
Loading
Loading