1111from .settings import config
1212from . import errors
1313from .dependencies import Dependencies
14+ from .plugin import connection_plugins
1415
1516logger = logging .getLogger (__name__ )
1617query_log_max_length = 300
1718
1819
20+ def get_host_hook (host_input ):
21+ if '://' in host_input :
22+ plugin_name = host_input .split ('://' )[0 ]
23+ try :
24+ return connection_plugins [plugin_name ]['object' ].load ().get_host (host_input )
25+ except KeyError :
26+ raise errors .DataJointError (
27+ "Connection plugin '{}' not found." .format (plugin_name ))
28+ else :
29+ return host_input
30+
31+
32+ def connect_host_hook (connection_obj ):
33+ if '://' in connection_obj .conn_info ['host_input' ]:
34+ plugin_name = connection_obj .conn_info ['host_input' ].split ('://' )[0 ]
35+ try :
36+ connection_plugins [plugin_name ]['object' ].load ().connect_host (connection_obj )
37+ except KeyError :
38+ raise errors .DataJointError (
39+ "Connection plugin '{}' not found." .format (plugin_name ))
40+ else :
41+ connection_obj .connect ()
42+
43+
1944def translate_query_error (client_error , query ):
2045 """
2146 Take client error and original query and return the corresponding DataJoint exception.
@@ -76,7 +101,8 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use
76101 #encrypted-connection-options).
77102 """
78103 if not hasattr (conn , 'connection' ) or reset :
79- host = host if host is not None else config ['database.host' ]
104+ host_input = host if host is not None else config ['database.host' ]
105+ host = get_host_hook (host_input )
80106 user = user if user is not None else config ['database.user' ]
81107 password = password if password is not None else config ['database.password' ]
82108 if user is None : # pragma: no cover
@@ -85,7 +111,8 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use
85111 password = getpass (prompt = "Please enter DataJoint password: " )
86112 init_fun = init_fun if init_fun is not None else config ['connection.init_function' ]
87113 use_tls = use_tls if use_tls is not None else config ['database.use_tls' ]
88- conn .connection = Connection (host , user , password , None , init_fun , use_tls )
114+ conn .connection = Connection (host , user , password , None , init_fun , use_tls ,
115+ host_input = host_input )
89116 return conn .connection
90117
91118
@@ -104,7 +131,8 @@ class Connection:
104131 :param use_tls: TLS encryption option
105132 """
106133
107- def __init__ (self , host , user , password , port = None , init_fun = None , use_tls = None ):
134+ def __init__ (self , host , user , password , port = None , init_fun = None , use_tls = None ,
135+ host_input = None ):
108136 if ':' in host :
109137 # the port in the hostname overrides the port argument
110138 host , port = host .split (':' )
@@ -115,10 +143,11 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
115143 if use_tls is not False :
116144 self .conn_info ['ssl' ] = use_tls if isinstance (use_tls , dict ) else {'ssl' : {}}
117145 self .conn_info ['ssl_input' ] = use_tls
146+ self .conn_info ['host_input' ] = host_input
118147 self .init_fun = init_fun
119148 print ("Connecting {user}@{host}:{port}" .format (** self .conn_info ))
120149 self ._conn = None
121- self . connect ( )
150+ connect_host_hook ( self )
122151 if self .is_connected :
123152 logger .info ("Connected {user}@{host}:{port}" .format (** self .conn_info ))
124153 self .connection_id = self .query ('SELECT connection_id()' ).fetchone ()[0 ]
@@ -149,15 +178,15 @@ def connect(self):
149178 "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
150179 charset = config ['connection.charset' ],
151180 ** {k : v for k , v in self .conn_info .items ()
152- if k != 'ssl_input' })
181+ if k not in [ 'ssl_input' , 'host_input' ] })
153182 except client .err .InternalError :
154183 self ._conn = client .connect (
155184 init_command = self .init_fun ,
156185 sql_mode = "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
157186 "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
158187 charset = config ['connection.charset' ],
159188 ** {k : v for k , v in self .conn_info .items ()
160- if not (k == 'ssl_input' or
189+ if not (k in [ 'ssl_input' , 'host_input' ] or
161190 k == 'ssl' and self .conn_info ['ssl_input' ] is None )})
162191 self ._conn .autocommit (True )
163192
@@ -194,7 +223,7 @@ def _execute_query(cursor, query, args, cursor_class, suppress_warnings):
194223 warnings .simplefilter ("ignore" )
195224 cursor .execute (query , args )
196225 except client .err .Error as err :
197- raise translate_query_error (err , query ) from None
226+ raise translate_query_error (err , query )
198227
199228 def query (self , query , args = (), * , as_dict = False , suppress_warnings = True , reconnect = None ):
200229 """
@@ -217,10 +246,10 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
217246 if not reconnect :
218247 raise
219248 warnings .warn ("MySQL server has gone away. Reconnecting to the server." )
220- self . connect ( )
249+ connect_host_hook ( self )
221250 if self ._in_transaction :
222251 self .cancel_transaction ()
223- raise errors .LostConnectionError ("Connection was lost during a transaction." ) from None
252+ raise errors .LostConnectionError ("Connection was lost during a transaction." )
224253 logger .debug ("Re-executing" )
225254 cursor = self ._conn .cursor (cursor = cursor_class )
226255 self ._execute_query (cursor , query , args , cursor_class , suppress_warnings )
0 commit comments