66import io
77
88from sqlalchemy import create_engine , text
9- from sqlalchemy .engine import Engine
9+ from sqlalchemy .engine import Engine , Connection
1010import pandas as pd
1111import sqlglot .expressions as exp
1212import sqlglot
@@ -59,30 +59,62 @@ class Connector:
5959 }
6060
6161 def __init__ (self , url : str , view_sql : str , engine_params : Optional [Dict [str , Any ]] = None ) -> "Connector" :
62- _check_view_sql (view_sql )
6362 if engine_params is None :
6463 engine_params = {}
6564
66- self .url = url
67- self .engine = self ._get_engine (engine_params )
65+ self ._init_instance (self ._get_or_create_engine (url , engine_params ), view_sql )
66+
67+ @classmethod
68+ def from_sqlalchemy_engine (cls , engine : Engine , view_sql : str ) -> "Connector" :
69+ """Create connector from engine"""
70+ instance = cls .__new__ (cls )
71+ instance ._init_instance (engine , view_sql )
72+ return instance
73+
74+ @classmethod
75+ def from_sqlalchemy_connection (cls , connection : Connection , view_sql : str ) -> "Connector" :
76+ """
77+ Create a Connector instance from an existing SQLAlchemy connection.
78+ This adapts the DuckDB connector.
79+
80+ Note:
81+ - All subsequent queries will use the same connection.
82+ - The caller is responsible for managing and closing the connection when no longer needed.
83+ """
84+ instance = cls .__new__ (cls )
85+ instance ._init_instance (connection .engine , view_sql )
86+ instance ._existing_conn = connection
87+ return instance
88+
89+ def _init_instance (self , engine : Engine , view_sql : str ):
90+ _check_view_sql (view_sql )
91+ self .engine = engine
92+ self .url = str (engine .url )
6893 self .view_sql = view_sql
6994 self ._json_type_code_set = self .JSON_TYPE_CODE_SET_MAP .get (self .dialect_name , set ())
95+ self ._existing_conn = None
96+ self ._run_pre_init_sql (engine )
7097
71- def _get_engine (self , engine_params : Dict [str , Any ]) -> Engine :
72- if self . url not in self .engine_map :
73- engine = create_engine (self . url , ** engine_params )
98+ def _get_or_create_engine (self , url : str , engine_params : Dict [str , Any ]) -> Engine :
99+ if url not in self .engine_map :
100+ engine = create_engine (url , ** engine_params )
74101 engine .dialect .requires_name_normalize = False
75- self .engine_map [self .url ] = engine
76- if engine .dialect .name in self .PRE_INIT_SQL_MAP :
77- pre_init_sql = self .PRE_INIT_SQL_MAP [engine .dialect .name ]
78- with engine .connect (True ) as connection :
79- connection .execute (text (pre_init_sql ))
102+ self .engine_map [url ] = engine
103+
104+ return self .engine_map [url ]
80105
81- return self .engine_map [self .url ]
106+ def _run_pre_init_sql (self , engine : Engine ) -> None :
107+ if engine .dialect .name in self .PRE_INIT_SQL_MAP :
108+ pre_init_sql = self .PRE_INIT_SQL_MAP [engine .dialect .name ]
109+ with engine .connect (True ) as connection :
110+ connection .execute (text (pre_init_sql ))
82111
83112 def query_datas (self , sql : str ) -> List [Dict [str , Any ]]:
84113 field_type_map = {}
85- with self .engine .connect () as connection :
114+ should_close_connection = self ._existing_conn is None
115+ connection = self ._existing_conn or self .engine .connect ()
116+
117+ try :
86118 result = connection .execute (text (sql ))
87119 if self .dialect_name in self .JSON_TYPE_CODE_SET_MAP :
88120 field_type_map = {
@@ -96,6 +128,9 @@ def query_datas(self, sql: str) -> List[Dict[str, Any]]:
96128 }
97129 for item in result .mappings ()
98130 ]
131+ finally :
132+ if should_close_connection :
133+ connection .close ()
99134
100135 @property
101136 def dialect_name (self ) -> str :
0 commit comments