1111from .settings import config
1212from . import errors
1313from .dependencies import Dependencies
14+ from .plugin import override
1415
1516# client errors to catch
1617client_errors = (client .err .InterfaceError , client .err .DatabaseError )
1718
1819
20+ def get_host_hook (host_input ):
21+ return host_input
22+
23+
24+ def connect_host_hook (connection_obj ):
25+ connection_obj .connect ()
26+
27+
28+ override ('connection' , globals (), ['get_host_hook' , 'connect_host_hook' ])
29+
30+
1931def translate_query_error (client_error , query ):
2032 """
2133 Take client error and original query and return the corresponding DataJoint exception.
@@ -76,7 +88,8 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use
7688 #encrypted-connection-options).
7789 """
7890 if not hasattr (conn , 'connection' ) or reset :
79- host = host if host is not None else config ['database.host' ]
91+ host_input = host if host is not None else config ['database.host' ]
92+ host = get_host_hook (host_input )
8093 user = user if user is not None else config ['database.user' ]
8194 password = password if password is not None else config ['database.password' ]
8295 if user is None : # pragma: no cover
@@ -85,7 +98,8 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use
8598 password = getpass (prompt = "Please enter DataJoint password: " )
8699 init_fun = init_fun if init_fun is not None else config ['connection.init_function' ]
87100 use_tls = use_tls if use_tls is not None else config ['database.use_tls' ]
88- conn .connection = Connection (host , user , password , None , init_fun , use_tls )
101+ conn .connection = Connection (host , user , password , None , init_fun , use_tls ,
102+ host_input = host_input )
89103 return conn .connection
90104
91105
@@ -104,7 +118,8 @@ class Connection:
104118 :param use_tls: TLS encryption option
105119 """
106120
107- def __init__ (self , host , user , password , port = None , init_fun = None , use_tls = None ):
121+ def __init__ (self , host , user , password , port = None , init_fun = None , use_tls = None ,
122+ host_input = None ):
108123 if ':' in host :
109124 # the port in the hostname overrides the port argument
110125 host , port = host .split (':' )
@@ -115,10 +130,11 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
115130 if use_tls is not False :
116131 self .conn_info ['ssl' ] = use_tls if isinstance (use_tls , dict ) else {'ssl' : {}}
117132 self .conn_info ['ssl_input' ] = use_tls
133+ self .conn_info ['host_input' ] = host_input
118134 self .init_fun = init_fun
119135 print ("Connecting {user}@{host}:{port}" .format (** self .conn_info ))
120136 self ._conn = None
121- self . connect ( )
137+ connect_host_hook ( self )
122138 if self .is_connected :
123139 logger .info ("Connected {user}@{host}:{port}" .format (** self .conn_info ))
124140 self .connection_id = self .query ('SELECT connection_id()' ).fetchone ()[0 ]
@@ -149,15 +165,15 @@ def connect(self):
149165 "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
150166 charset = config ['connection.charset' ],
151167 ** {k : v for k , v in self .conn_info .items ()
152- if k != 'ssl_input' })
168+ if k not in [ 'ssl_input' , 'host_input' ] })
153169 except client .err .InternalError :
154170 self ._conn = client .connect (
155171 init_command = self .init_fun ,
156172 sql_mode = "NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
157173 "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION" ,
158174 charset = config ['connection.charset' ],
159175 ** {k : v for k , v in self .conn_info .items ()
160- if not (k == 'ssl_input' or
176+ if not (k in [ 'ssl_input' , 'host_input' ] or
161177 k == 'ssl' and self .conn_info ['ssl_input' ] is None )})
162178 self ._conn .autocommit (True )
163179
@@ -193,7 +209,7 @@ def __execute_query(cursor, query, args, cursor_class, suppress_warnings):
193209 warnings .simplefilter ("ignore" )
194210 cursor .execute (query , args )
195211 except client_errors as err :
196- raise translate_query_error (err , query ) from None
212+ raise translate_query_error (err , query )
197213
198214 def query (self , query , args = (), * , as_dict = False , suppress_warnings = True , reconnect = None ):
199215 """
@@ -216,10 +232,10 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
216232 if not reconnect :
217233 raise
218234 warnings .warn ("MySQL server has gone away. Reconnecting to the server." )
219- self . connect ( )
235+ connect_host_hook ( self )
220236 if self ._in_transaction :
221237 self .cancel_transaction ()
222- raise errors .LostConnectionError ("Connection was lost during a transaction." ) from None
238+ raise errors .LostConnectionError ("Connection was lost during a transaction." )
223239 logger .debug ("Re-executing" )
224240 cursor = self ._conn .cursor (cursor = cursor_class )
225241 self .__execute_query (cursor , query , args , cursor_class , suppress_warnings )
0 commit comments