8989ENGINES = {}
9090CONNECTIONS = {}
9191ENGINE_KEY = '%(username)s-%(connector_name)s'
92+ URL_PATTERN = '(?P<driver_name>.+?://)(?P<host>[^:/ ]+):(?P<port>[0-9]*).*'
9293
9394LOG = logging .getLogger (__name__ )
9495
@@ -171,6 +172,18 @@ def _create_engine(self):
171172 s3_staging_dir = url .rsplit ('s3_staging_dir=' , 1 )[1 ]
172173 url = url .replace (s3_staging_dir , urllib_quote_plus (s3_staging_dir ))
173174
175+ m = re .search (URL_PATTERN , url )
176+ driver_name = m .group ('driver_name' )
177+ if self .options .get ('has_impersonation' ):
178+ if not driver_name :
179+ raise QueryError ('Driver name of %(url)s could not be found and impersonation is turned on' % {'url' : url })
180+
181+ if driver_name .startswith ("phoenix" ):
182+ url = url .replace (driver_name , '%(driver_name)s%(username)s@' % {
183+ 'driver_name' : driver_name ,
184+ 'username' : self .user .username
185+ })
186+
174187 if self .options .get ('credentials_json' ):
175188 self .options ['credentials_info' ] = json .loads (
176189 self .options .pop ('credentials_json' )
@@ -183,7 +196,8 @@ def _create_engine(self):
183196 self .options .pop ('connect_args' )
184197 )
185198
186- if self .options .get ('has_impersonation' ):
199+ # phoenixdb does not support impersonation using principal_username parameter
200+ if self .options .get ('has_impersonation' ) and not driver_name .startswith ("phoenix" ):
187201 self .options .setdefault ('connect_args' , {}).setdefault ('principal_username' , self .user .username )
188202
189203 options = self .options .copy ()
@@ -258,7 +272,7 @@ def execute(self, notebook, snippet):
258272 }
259273 CONNECTIONS [guid ] = cache
260274
261- response = {
275+ response = {
262276 'sync' : False ,
263277 'has_result_set' : result .cursor != None ,
264278 'modified_row_count' : 0 ,
@@ -330,7 +344,7 @@ def check_status(self, notebook, snippet):
330344 @query_error_handler
331345 def progress (self , notebook , snippet , logs = '' ):
332346 progress = 50
333- if self .options ['url' ].startswith ('presto://' ) | self .options ['url' ].startswith ('trino://' ) :
347+ if self .options ['url' ].startswith ('presto://' ) | self .options ['url' ].startswith ('trino://' ):
334348 guid = snippet ['result' ]['handle' ]['guid' ]
335349 handle = CONNECTIONS .get (guid )
336350 stats = None
0 commit comments