Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,11 @@ PGVECTOR_COLLECTION=table_info_db

# VectorDB 설정
VECTORDB_TYPE=faiss # faiss 또는 pgvector


# Trino settings (uncomment and adjust to use the Trino connector)
# TRINO_HOST=localhost
# TRINO_PORT=8080
# TRINO_USER=admin
# TRINO_PASSWORD=password
# TRINO_CATALOG=delta
# TRINO_SCHEMA=default
2 changes: 2 additions & 0 deletions db_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .duckdb_connector import DuckDBConnector
from .databricks_connector import DatabricksConnector
from .snowflake_connector import SnowflakeConnector
from .trino_connector import TrinoConnector

env_path = os.path.join(os.getcwd(), ".env")

Expand Down Expand Up @@ -54,6 +55,7 @@ def get_db_connector(db_type: Optional[str] = None, config: Optional[DBConfig] =
"duckdb": DuckDBConnector,
"databricks": DatabricksConnector,
"snowflake": SnowflakeConnector,
"trino": TrinoConnector,
}

if db_type not in connector_map:
Expand Down
120 changes: 120 additions & 0 deletions db_utils/trino_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import pandas as pd
from .base_connector import BaseConnector
from .config import DBConfig
from .logger import logger

try:
import trino
except Exception as e: # pragma: no cover
trino = None
_import_error = e


class TrinoConnector(BaseConnector):
    """
    Connect to Trino and execute SQL queries.

    The connection is opened eagerly in ``__init__`` and kept on
    ``self.connection`` until ``close()`` is called.
    """

    # Class-level default so ``close()`` is safe even if ``connect()`` never
    # succeeded on this instance.
    connection = None

    def __init__(self, config: DBConfig) -> None:
        """
        Initialize the TrinoConnector with connection parameters.

        Parameters:
            config (DBConfig): Configuration object containing connection
                parameters. ``host`` is required. ``port`` defaults to 8080
                and ``user`` to ``"anonymous"`` when missing or empty.
                ``extra`` may carry ``http_scheme`` (default ``"http"``),
                ``catalog`` and ``schema``.

        Raises:
            KeyError: If ``host`` is missing from ``config``.
            Exception: Whatever ``connect()`` raises.
        """
        self.host = config["host"]
        # Use .get() so a config without a "port" key falls back to the
        # default instead of raising KeyError.
        self.port = config.get("port") or 8080
        self.user = config.get("user") or "anonymous"
        self.password = config.get("password")
        self.database = config.get("database")  # e.g., catalog.schema
        self.extra = config.get("extra") or {}
        self.http_scheme = self.extra.get("http_scheme", "http")
        self.catalog = self.extra.get("catalog")
        self.schema = self.extra.get("schema")

        # If database is given as "catalog.schema", use it to fill in whichever
        # of catalog/schema was not set explicitly via ``extra``.
        needs_fill = not self.catalog or not self.schema
        if self.database and needs_fill and "." in self.database:
            db_catalog, db_schema = self.database.split(".", 1)
            self.catalog = self.catalog or db_catalog
            self.schema = self.schema or db_schema

        self.connect()

    def connect(self) -> None:
        """
        Establish a connection to the Trino cluster.

        Raises:
            Exception: The original import error if the ``trino`` driver could
                not be imported, or any error raised while creating the
                connection (logged, then re-raised).
        """
        if trino is None:
            logger.error(f"Failed to import trino driver: {_import_error}")
            raise _import_error

        try:
            auth = None
            if self.password:
                # If HTTP, ignore password to avoid insecure auth error
                if self.http_scheme == "http":
                    logger.warning(
                        "Password provided for HTTP; ignoring password. "
                        "Set TRINO_HTTP_SCHEME=https to enable password authentication."
                    )
                else:
                    # Basic auth over HTTPS
                    auth = trino.auth.BasicAuthentication(self.user, self.password)

            self.connection = trino.dbapi.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                http_scheme=self.http_scheme,
                catalog=self.catalog,
                schema=self.schema,
                auth=auth,
            )
            logger.info("Successfully connected to Trino.")
        except Exception as e:
            logger.error(f"Failed to connect to Trino: {e}")
            raise

    def run_sql(self, sql: str) -> pd.DataFrame:
        """
        Execute a SQL query and return the result as a pandas DataFrame.

        Parameters:
            sql (str): SQL query string to be executed.

        Returns:
            pd.DataFrame: Result of the SQL query as a pandas DataFrame. When
                the cursor reports no description after executing the
                statement, an empty DataFrame with no columns is returned.

        Raises:
            Exception: Any error raised while obtaining the cursor or running
                the query (logged, then re-raised).
        """
        # Bind before the try so the finally clause never touches an unbound
        # name when obtaining the cursor itself fails.
        cursor = None
        try:
            cursor = self.connection.cursor()
            cursor.execute(sql)
            description = cursor.description
            if description:
                columns = [desc[0] for desc in description]
                rows = cursor.fetchall()
            else:
                columns, rows = [], []
            return pd.DataFrame(rows, columns=columns)
        except Exception as e:
            logger.error(f"Failed to execute SQL query on Trino: {e}")
            raise
        finally:
            if cursor is not None:
                try:
                    cursor.close()
                except Exception as e:
                    # Best-effort cleanup: a failing close must not mask the
                    # query result or the original error.
                    logger.debug(f"Ignoring error while closing Trino cursor: {e}")

    def close(self) -> None:
        """
        Close the connection to the Trino cluster.

        Closing is best-effort: an error raised by the driver is logged as a
        warning rather than propagated, and ``self.connection`` is reset to
        ``None`` either way. Does nothing if there is no open connection.
        """
        if not self.connection:
            return
        try:
            self.connection.close()
        except Exception as e:
            logger.warning(f"Error while closing Trino connection: {e}")
        else:
            logger.info("Connection to Trino closed.")
        finally:
            self.connection = None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"google-cloud-bigquery>=3.20.1,<4.0.0",
"pgvector==0.3.6",
"langchain-postgres==0.0.15",
"trino>=0.329.0,<1.0.0",
]

[project.scripts]
Expand Down Expand Up @@ -82,4 +83,3 @@ dev-dependencies = [
"pytest>=8.3.5",
]