1414limitations under the License. 
1515""" 
1616
17+ from  __future__ import  annotations 
18+ 
19+ from  abc  import  ABC 
20+ from  abc  import  abstractmethod 
1721import  asyncio 
22+ from  functools  import  partial 
23+ import  logging 
1824import  os 
1925from  pathlib  import  Path 
20- import  socket 
21- import  selectors 
22- import  ssl 
26+ from  typing  import  Callable , List 
2327
24- from   google . cloud . sql . connector . exceptions   import   LocalProxyStartupError 
28+ logger   =   logging . getLogger ( name = __name__ ) 
2529
26- LOCAL_PROXY_MAX_MESSAGE_SIZE  =  10485760 
2730
31+ class  BaseProxyProtocol (asyncio .Protocol ):
32+     """ 
33+     A protocol to proxy data between two transports. 
34+     """ 
2835
29- class  Proxy :
30-     """Creates an "accept loop" async task which will open the unix server socket and listen for new connections.""" 
36+     def  __init__ (self , proxy : Proxy ):
37+         super ().__init__ ()
38+         self .proxy  =  proxy 
39+         self ._buffer  =  bytearray ()
40+         self ._target : asyncio .Transport  |  None  =  None 
41+         self .transport : asyncio .Transport  |  None  =  None 
42+         self ._cached : List [bytes ] =  []
43+         logger .debug (f"__init__  { self }  " )
44+ 
45+     def  connection_made (self , transport ):
46+         logger .debug (f"connection_made { self }  " )
47+         self .transport  =  transport 
48+ 
49+     def  data_received (self , data ):
50+         if  self ._target  is  None :
51+             self ._cached .append (data )
52+         else :
53+             self ._target .write (data )
54+ 
55+     def  set_target (self , target : asyncio .Transport ):
56+         logger .debug (f"set_target { self }  " )
57+         self ._target  =  target 
58+         if  self ._cached :
59+             self ._target .writelines (self ._cached )
60+             self ._cached  =  []
61+ 
62+     def  eof_received (self ):
63+         logger .debug (f"eof_received { self }  " )
64+         if  self ._target  is  not   None :
65+             self ._target .write_eof ()
66+ 
67+     def  connection_lost (self , exc : Exception  |  None ):
68+         logger .debug (f"connection_lost { exc }   { self }  " )
69+         if  self ._target  is  not   None :
70+             self ._target .close ()
71+ 
72+ 
73+ class  ProxyClientConnection :
74+     """ 
75+     Holds all of the tasks and details for a client proxy 
76+     """ 
3177
3278    def  __init__ (
3379        self ,
34-         connector ,
35-         instance_connection_string : str ,
36-         socket_path : str ,
37-         loop : asyncio .AbstractEventLoop ,
38-         ** kwargs 
39-     ) ->  None :
40-         """Keeps track of all the async tasks and starts the accept loop for new connections. 
41-          
42-         Args: 
43-             connector (Connector): The instance where this Proxy class was created. 
80+         client_transport : asyncio .Transport ,
81+         client_protocol : ClientToServerProtocol ,
82+     ):
83+         self .client_transport  =  client_transport 
84+         self .client_protocol  =  client_protocol 
85+         self .server_transport : asyncio .Transport  |  None  =  None 
86+         self .server_protocol : ServerToClientProtocol  |  None  =  None 
87+         self .task : asyncio .Task  |  None  =  None 
88+ 
89+     def  close (self ):
90+         logger .debug (f"closing { self }  " )
91+         if  self .client_transport  is  not   None :
92+             self ._close_transport (self .client_transport )
93+         if  self .server_transport  is  not   None :
94+             self ._close_transport (self .server_transport )
95+ 
96+     def  _close_transport (self , transport :asyncio .Transport ):
97+         if  transport .is_closing ():
98+             return 
99+         if  transport .can_write_eof ():
100+             transport .write_eof ()
101+         else :
102+             transport .close ()
103+ 
104+ class  ClientToServerProtocol (BaseProxyProtocol ):
105+     """ 
106+     Protocol to copy bytes from the unix socket client to the database server 
107+     """ 
108+ 
109+     def  __init__ (self , proxy : Proxy ):
110+         super ().__init__ (proxy )
111+         self ._buffer  =  bytearray ()
112+         self ._target : asyncio .Transport  |  None  =  None 
113+         logger .debug (f"__init__ { self }  " )
114+ 
115+     def  connection_made (self , transport ):
116+         # When a connection is made, open the server connection 
117+         super ().connection_made (transport )
118+         self .proxy ._handle_client_connection (transport , self )
44119
45-             instance_connection_string (str): The instance connection name of the 
46-                 Cloud SQL instance to connect to. Takes the form of 
47-                 "project-id:region:instance-name" 
48120
49-                 Example: "my-project:us-central1:my-instance" 
121+ class  ServerToClientProtocol (BaseProxyProtocol ):
122+     """ 
123+     Protocol to copy bytes from the database server to the client socket 
124+     """ 
50125
51-             socket_path (str): A system path that is going to be used to store the socket. 
126+     def  __init__ (self , proxy : Proxy , cconn : ProxyClientConnection ):
127+         super ().__init__ (proxy )
128+         self ._buffer  =  bytearray ()
129+         self ._target  =  cconn .client_transport 
130+         self ._client_protocol  =  cconn .client_protocol 
131+         logger .debug (f"__init__ { self }  " )
52132
53-             loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. 
133+     def  connection_made (self , transport ):
134+         super ().connection_made (transport )
135+         self ._client_protocol .set_target (transport )
54136
55-             **kwargs: Any driver-specific arguments to pass to the underlying 
56-                 driver .connect call. 
137+     def  connection_lost (self , exc : Exception  |  None ):
138+         super ().connection_lost (exc )
139+         self .proxy ._handle_server_connection_lost ()
140+ 
141+ class  ServerConnectionFactory (ABC ):
142+     """ 
143+     ServerConnectionFactory is an abstract class that provides connections to the service. 
144+     """ 
145+     @abstractmethod  
146+     async  def  connect (self , protocol_fn : Callable [[], asyncio .Protocol ]):
147+         """ 
148+         Establishes a connection to the server and configures it to use the protocol 
149+         returned from protocol_fn, with asyncio.EventLoop.create_connection(). 
150+         :param protocol_fn: the protocol function 
151+         :return: None 
57152        """ 
58-         self ._connection_tasks  =  []
59-         self ._addr  =  instance_connection_string 
60-         self ._kwargs  =  kwargs 
61-         self ._connector  =  connector 
153+         pass 
62154
63-         unix_socket  =  None 
155+ class  Proxy :
156+     """ 
157+     A class to represent a local Unix socket proxy for a Cloud SQL instance. 
158+     This class manages a Unix socket that listens for incoming connections and 
159+     proxies them to a Cloud SQL instance. 
160+     """ 
64161
65-         try :
66-             path_parts  =  socket_path .rsplit ('/' , 1 )
67-             parent_directory  =  '/' .join (path_parts [:- 1 ])
162+     def  __init__ (
163+         self ,
164+         unix_socket_path : str ,
165+         server_connection_factory : ServerConnectionFactory ,
166+         loop : asyncio .AbstractEventLoop ,
167+     ):
168+         """ 
169+         Creates a new Proxy 
170+         :param unix_socket_path: the path to listen for the proxy connection 
171+         :param loop: The event loop 
172+         :param instance_connect: A function that will establish the async connection to the server 
173+ 
174+         The instance_connect function is an asynchronous function that should set up a new connection. 
175+         It takes one argument - another function that 
176+         """ 
177+         self .unix_socket_path  =  unix_socket_path 
178+         self .alive  =  True 
179+         self ._loop  =  loop 
180+         self ._server : asyncio .AbstractServer  |  None  =  None 
181+         self ._client_connections : set [ProxyClientConnection ] =  set ()
182+         self ._server_connection_factory  =  server_connection_factory 
68183
69-             desired_path  =  Path (parent_directory )
70-             desired_path .mkdir (parents = True , exist_ok = True )
184+     async  def  start (self ) ->  None :
185+         """Starts the Unix socket server.""" 
186+         if  os .path .exists (self .unix_socket_path ):
187+             os .remove (self .unix_socket_path )
71188
72-              if   os . path . exists ( socket_path ): 
73-                  os . remove ( socket_path )
189+         parent_dir   =   Path ( self . unix_socket_path ). parent 
190+         parent_dir . mkdir ( parents = True ,  exist_ok = True )
74191
75-             unix_socket  =  socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
192+         def  new_protocol () ->  ClientToServerProtocol :
193+             return  ClientToServerProtocol (self )
76194
77-             unix_socket .bind (socket_path )
78-             unix_socket .listen (1 )
79-             unix_socket .setblocking (False )
80-             os .chmod (socket_path , 0o600 )
195+         logger .debug (f"Socket path: { self .unix_socket_path }  " )
196+         self ._server  =  await  self ._loop .create_unix_server (
197+             new_protocol , path = self .unix_socket_path 
198+         )
199+         self ._loop .create_task (self ._server .serve_forever ())
81200
82-             self ._task  =  loop .create_task (self .accept_loop (unix_socket , socket_path , loop ))
201+     def  _handle_client_connection (
202+         self ,
203+         client_transport : asyncio .Transport ,
204+         client_protocol : ClientToServerProtocol ,
205+     ) ->  None :
206+         """ 
207+         Register a new client connection and initiate the task to create a database connection. 
208+         This is called by ClientToServerProtocol.connection_made 
83209
84-         except  Exception :
85-             raise  LocalProxyStartupError (
86-                 'Local UNIX socket based proxy was not able to get started.' 
87-             )
210+         :param client_transport: the client transport for the client unix socket 
211+         :param client_protocol:  the instance for the 
212+         :return: None 
213+         """ 
214+         conn  =  ProxyClientConnection (client_transport , client_protocol )
215+         self ._client_connections .add (conn )
216+         conn .task  =  self ._loop .create_task (self ._create_db_instance_connection (conn ))
217+         conn .task .add_done_callback (lambda  _ : self ._client_connections .discard (conn ))
88218
89-     async   def  accept_loop (
219+     def  _handle_server_connection_lost (
90220        self ,
91-         unix_socket ,
92-         socket_path : str ,
93-         loop : asyncio .AbstractEventLoop 
94-     ) ->  asyncio .Task :
95-         """Starts a UNIX based local proxy for transporting messages through 
96-         the SSL Socket, and waits until there is a new connection to accept, to register it 
97-         and keep track of it. 
221+     ) ->  None :
222+         """ 
223+         Closes the proxy server if the connection to the server is lost 
98224         
99-         Args: 
100-             socket_path: A system path that is going to be used to store the socket. 
225+         :return: None 
226+         """ 
227+         logger .debug (f"Closing proxy server due to lost connection" )
228+         self ._loop .create_task (self .close ())
229+ 
230+     async  def  _create_db_instance_connection (self , conn : ProxyClientConnection ) ->  None :
231+         """ 
232+         Manages a single proxy connection from a client to the Cloud SQL instance. 
233+         """ 
234+         try :
235+             logger .debug ("_proxy_connection() started" )
236+             new_protocol  =  partial (ServerToClientProtocol , self , conn )
237+ 
238+             # Establish connection to the database 
239+             await  self ._server_connection_factory .connect (new_protocol )
240+             logger .debug ("_proxy_connection() succeeded" )
101241
102-             loop (asyncio.AbstractEventLoop): Event loop to run asyncio tasks. 
242+         except  Exception  as  e :
243+             logger .error (f"Error handling proxy connection: { e }  " )
244+             await  self .close ()
245+             raise  e 
103246
104-         Raises: 
105-             LocalProxyStartupError: Local UNIX socket based proxy was not able to 
106-             get started. 
247+     async  def  close (self ) ->  None :
107248        """ 
108-         print ("on accept loop" )
109-         while  True :
110-             client , _  =  await  loop .sock_accept (unix_socket )
111-             self ._connection_tasks .append (loop .create_task (self .client_socket (client , unix_socket , socket_path , loop ))) 
249+         Shuts down the proxy server and cleans up resources. 
250+         """ 
251+         logger .info (f"Closing Unix socket proxy at { self .unix_socket_path }  " )
112252
113-     async  def  close_async (self ):
114-         proxy_task  =  asyncio .gather (self ._task )
115-         try :
116-             await  asyncio .wait_for (proxy_task , timeout = 0.1 )
117-         except  (asyncio .CancelledError , asyncio .TimeoutError , TimeoutError ):
118-             pass  # This task runs forever so it is expected to throw this exception 
253+         if  self ._server :
254+             self ._server .close ()
255+             await  self ._server .wait_closed ()
119256
257+         if  self ._client_connections :
258+             for  conn  in  list (self ._client_connections ):
259+                 conn .close ()
260+             await  asyncio .wait ([c .task  for  c  in  self ._client_connections  if  c .task  is  not   None ], timeout = 0.1 )
120261
121-     async  def  client_socket (
122-         self , client , unix_socket , socket_path , loop 
123-     ):
124-         try :
125-             ssl_sock  =  self ._connector .connect (
126-                 self ._addr ,
127-                 'local_unix_socket' ,
128-                 ** self ._kwargs 
129-             )
130-             while  True :
131-                 data  =  await  loop .sock_recv (client , LOCAL_PROXY_MAX_MESSAGE_SIZE )
132-                 if  not  data :
133-                     client .close ()
134-                     break 
135-                 ssl_sock .sendall (data )
136-                 response  =  ssl_sock .recv (LOCAL_PROXY_MAX_MESSAGE_SIZE )
137-                 await  loop .sock_sendall (client , response )
138-         finally :
139-             client .close ()
140-             os .remove (socket_path ) # Clean up the socket file 
262+         if  os .path .exists (self .unix_socket_path ):
263+             os .remove (self .unix_socket_path )
264+ 
265+         logger .info (f"Unix socket proxy for { self .unix_socket_path }   closed." )
266+         self .alive  =  False 
0 commit comments