Skip to content

Commit e61e1b6

Browse files
committed
trino 커넥터 추가
1 parent 1275902 commit e61e1b6

File tree

4 files changed

+131
-1
lines changed

4 files changed

+131
-1
lines changed

.env.example

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,11 @@ PGVECTOR_COLLECTION=table_info_db
153153

154154
# VectorDB 설정
155155
VECTORDB_TYPE=faiss # faiss 또는 pgvector
156+
157+
158+
# TRINO_HOST=localhost
159+
# TRINO_PORT=8080
160+
# TRINO_USER=admin
161+
# TRINO_PASSWORD=password
162+
# TRINO_CATALOG=delta
163+
# TRINO_SCHEMA=default

db_utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .duckdb_connector import DuckDBConnector
1414
from .databricks_connector import DatabricksConnector
1515
from .snowflake_connector import SnowflakeConnector
16+
from .trino_connector import TrinoConnector
1617

1718
env_path = os.path.join(os.getcwd(), ".env")
1819

@@ -54,6 +55,7 @@ def get_db_connector(db_type: Optional[str] = None, config: Optional[DBConfig] =
5455
"duckdb": DuckDBConnector,
5556
"databricks": DatabricksConnector,
5657
"snowflake": SnowflakeConnector,
58+
"trino": TrinoConnector,
5759
}
5860

5961
if db_type not in connector_map:

db_utils/trino_connector.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
import pandas as pd
2+
from .base_connector import BaseConnector
3+
from .config import DBConfig
4+
from .logger import logger
5+
6+
try:
7+
import trino
8+
except Exception as e: # pragma: no cover
9+
trino = None
10+
_import_error = e
11+
12+
13+
class TrinoConnector(BaseConnector):
14+
"""
15+
Connect to Trino and execute SQL queries.
16+
"""
17+
18+
connection = None
19+
20+
def __init__(self, config: DBConfig):
21+
"""
22+
Initialize the TrinoConnector with connection parameters.
23+
24+
Parameters:
25+
config (DBConfig): Configuration object containing connection parameters.
26+
"""
27+
self.host = config["host"]
28+
self.port = config["port"] or 8080
29+
self.user = config.get("user") or "anonymous"
30+
self.password = config.get("password")
31+
self.database = config.get("database") # e.g., catalog.schema
32+
self.extra = config.get("extra") or {}
33+
self.http_scheme = self.extra.get("http_scheme", "http")
34+
self.catalog = self.extra.get("catalog")
35+
self.schema = self.extra.get("schema")
36+
37+
# If database given as "catalog.schema", split into fields
38+
if self.database and (not self.catalog or not self.schema):
39+
if "." in self.database:
40+
db_catalog, db_schema = self.database.split(".", 1)
41+
self.catalog = self.catalog or db_catalog
42+
self.schema = self.schema or db_schema
43+
44+
self.connect()
45+
46+
def connect(self) -> None:
47+
"""
48+
Establish a connection to the Trino cluster.
49+
"""
50+
if trino is None:
51+
logger.error(f"Failed to import trino driver: {_import_error}")
52+
raise _import_error
53+
54+
try:
55+
auth = None
56+
if self.password:
57+
# If HTTP, ignore password to avoid insecure auth error
58+
if self.http_scheme == "http":
59+
logger.warning(
60+
"Password provided for HTTP; ignoring password. "
61+
"Set TRINO_HTTP_SCHEME=https to enable password authentication."
62+
)
63+
else:
64+
# Basic auth over HTTPS
65+
auth = trino.auth.BasicAuthentication(self.user, self.password)
66+
67+
self.connection = trino.dbapi.connect(
68+
host=self.host,
69+
port=self.port,
70+
user=self.user,
71+
http_scheme=self.http_scheme,
72+
catalog=self.catalog,
73+
schema=self.schema,
74+
auth=auth,
75+
# Optional: session properties
76+
# session_properties={}
77+
)
78+
logger.info("Successfully connected to Trino.")
79+
except Exception as e:
80+
logger.error(f"Failed to connect to Trino: {e}")
81+
raise
82+
83+
def run_sql(self, sql: str) -> pd.DataFrame:
84+
"""
85+
Execute a SQL query and return the result as a pandas DataFrame.
86+
87+
Parameters:
88+
sql (str): SQL query string to be executed.
89+
90+
Returns:
91+
pd.DataFrame: Result of the SQL query as a pandas DataFrame.
92+
"""
93+
try:
94+
cursor = self.connection.cursor()
95+
cursor.execute(sql)
96+
columns = (
97+
[desc[0] for desc in cursor.description] if cursor.description else []
98+
)
99+
rows = cursor.fetchall() if cursor.description else []
100+
return pd.DataFrame(rows, columns=columns)
101+
except Exception as e:
102+
logger.error(f"Failed to execute SQL query on Trino: {e}")
103+
raise
104+
finally:
105+
try:
106+
cursor.close()
107+
except Exception:
108+
pass
109+
110+
def close(self) -> None:
111+
"""
112+
Close the connection to the Trino cluster.
113+
"""
114+
if self.connection:
115+
try:
116+
self.connection.close()
117+
except Exception:
118+
pass
119+
logger.info("Connection to Trino closed.")
120+
self.connection = None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ dependencies = [
4545
"google-cloud-bigquery>=3.20.1,<4.0.0",
4646
"pgvector==0.3.6",
4747
"langchain-postgres==0.0.15",
48+
"trino>=0.329.0,<1.0.0",
4849
]
4950

5051
[project.scripts]
@@ -82,4 +83,3 @@ dev-dependencies = [
8283
"pytest>=8.3.5",
8384
]
8485

85-

0 commit comments

Comments
 (0)