1717from  __future__ import  annotations 
1818
1919import  asyncio 
20- from  dataclasses  import  dataclass 
2120from  enum  import  Enum 
2221import  logging 
2322import  re 
24- import  ssl 
25- from  tempfile  import  TemporaryDirectory 
26- from  typing  import  Any , Dict , Tuple , TYPE_CHECKING 
23+ from  typing  import  Tuple 
2724
2825import  aiohttp 
29- from  google .auth .credentials  import  TokenState 
30- from  google .auth .transport  import  requests 
3126
3227from  google .cloud .sql .connector .client  import  CloudSQLClient 
33- from  google .cloud .sql .connector .exceptions  import  AutoIAMAuthNotSupported 
34- from  google .cloud .sql .connector .exceptions  import  CloudSQLIPTypeError 
28+ from  google .cloud .sql .connector .connection_info  import  ConnectionInfo 
3529from  google .cloud .sql .connector .exceptions  import  RefreshNotValidError 
36- from  google .cloud .sql .connector .exceptions  import  TLSVersionError 
3730from  google .cloud .sql .connector .rate_limiter  import  AsyncRateLimiter 
3831from  google .cloud .sql .connector .refresh_utils  import  _is_valid 
3932from  google .cloud .sql .connector .refresh_utils  import  _seconds_until_refresh 
40- from  google .cloud .sql .connector .utils  import  write_to_file 
41- 
42- if  TYPE_CHECKING :
43-     import  datetime 
4433
4534logger  =  logging .getLogger (name = __name__ )
4635
@@ -83,79 +72,6 @@ def _from_str(cls, ip_type_str: str) -> IPTypes:
8372        return  cls (ip_type_str .upper ())
8473
8574
86- @dataclass  
87- class  ConnectionInfo :
88-     """Contains all necessary information to connect securely to the 
89-     server-side Proxy running on a Cloud SQL instance.""" 
90- 
91-     client_cert : str 
92-     server_ca_cert : str 
93-     private_key : bytes 
94-     ip_addrs : Dict [str , Any ]
95-     database_version : str 
96-     expiration : datetime .datetime 
97-     context : ssl .SSLContext  |  None  =  None 
98- 
99-     def  create_ssl_context (self , enable_iam_auth : bool  =  False ) ->  ssl .SSLContext :
100-         """Constructs a SSL/TLS context for the given connection info. 
101- 
102-         Cache the SSL context to ensure we don't read from disk repeatedly when 
103-         configuring a secure connection. 
104-         """ 
105-         # if SSL context is cached, use it 
106-         if  self .context  is  not   None :
107-             return  self .context 
108-         context  =  ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
109- 
110-         # update ssl.PROTOCOL_TLS_CLIENT default 
111-         context .check_hostname  =  False 
112- 
113-         # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been 
114-         # implemented. The ssl module requires OpenSSL 1.1.1 or newer. 
115-         # verify OpenSSL version supports TLSv1.3 
116-         if  ssl .HAS_TLSv1_3 :
117-             # force TLSv1.3 if supported by client 
118-             context .minimum_version  =  ssl .TLSVersion .TLSv1_3 
119-         # fallback to TLSv1.2 for older versions of OpenSSL 
120-         else :
121-             if  enable_iam_auth :
122-                 raise  TLSVersionError (
123-                     f"Your current version of OpenSSL ({ ssl .OPENSSL_VERSION }  ) does not " 
124-                     "support TLSv1.3, which is required to use IAM Authentication.\n " 
125-                     "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." 
126-                 )
127-             logger .warning (
128-                 "TLSv1.3 is not supported with your version of OpenSSL " 
129-                 f"({ ssl .OPENSSL_VERSION }  ), falling back to TLSv1.2\n " 
130-                 "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support." 
131-             )
132-             context .minimum_version  =  ssl .TLSVersion .TLSv1_2 
133- 
134-         # tmpdir and its contents are automatically deleted after the CA cert 
135-         # and ephemeral cert are loaded into the SSLcontext. The values 
136-         # need to be written to files in order to be loaded by the SSLContext 
137-         with  TemporaryDirectory () as  tmpdir :
138-             ca_filename , cert_filename , key_filename  =  write_to_file (
139-                 tmpdir , self .server_ca_cert , self .client_cert , self .private_key 
140-             )
141-             context .load_cert_chain (cert_filename , keyfile = key_filename )
142-             context .load_verify_locations (cafile = ca_filename )
143-         # set class attribute to cache context for subsequent calls 
144-         self .context  =  context 
145-         return  context 
146- 
147-     def  get_preferred_ip (self , ip_type : IPTypes ) ->  str :
148-         """Returns the first IP address for the instance, according to the preference 
149-         supplied by ip_type. If no IP addressess with the given preference are found, 
150-         an error is raised.""" 
151-         if  ip_type .value  in  self .ip_addrs :
152-             return  self .ip_addrs [ip_type .value ]
153-         raise  CloudSQLIPTypeError (
154-             "Cloud SQL instance does not have any IP addresses matching " 
155-             f"preference: { ip_type .value }  )" 
156-         )
157- 
158- 
15975class  RefreshAheadCache :
16076    """Cache that refreshes connection info in the background prior to expiration. 
16177
@@ -229,45 +145,13 @@ async def _perform_refresh(self) -> ConnectionInfo:
229145
230146        try :
231147            await  self ._refresh_rate_limiter .acquire ()
232-             priv_key , pub_key  =  await  self ._keys 
233- 
234-             logger .debug (f"['{ self ._instance_connection_string }  ']: Creating context" )
235- 
236-             # before making Cloud SQL Admin API calls, refresh creds 
237-             if  not  self ._client ._credentials .token_state  ==  TokenState .FRESH :
238-                 self ._client ._credentials .refresh (requests .Request ())
239- 
240-             metadata_task  =  asyncio .create_task (
241-                 self ._client ._get_metadata (
242-                     self ._project ,
243-                     self ._region ,
244-                     self ._instance ,
245-                 )
246-             )
247- 
248-             ephemeral_task  =  asyncio .create_task (
249-                 self ._client ._get_ephemeral (
250-                     self ._project ,
251-                     self ._instance ,
252-                     pub_key ,
253-                     self ._enable_iam_auth ,
254-                 )
148+             connection_info  =  await  self ._client .get_connection_info (
149+                 self ._project ,
150+                 self ._region ,
151+                 self ._instance ,
152+                 self ._keys ,
153+                 self ._enable_iam_auth ,
255154            )
256-             try :
257-                 metadata  =  await  metadata_task 
258-                 # check if automatic IAM database authn is supported for database engine 
259-                 if  self ._enable_iam_auth  and  not  metadata [
260-                     "database_version" 
261-                 ].startswith (("POSTGRES" , "MYSQL" )):
262-                     raise  AutoIAMAuthNotSupported (
263-                         f"'{ metadata ['database_version' ]}  ' does not support automatic IAM authentication. It is only supported with Cloud SQL Postgres or MySQL instances." 
264-                     )
265-             except  Exception :
266-                 # cancel ephemeral cert task if exception occurs before it is awaited 
267-                 ephemeral_task .cancel ()
268-                 raise 
269- 
270-             ephemeral_cert , expiration  =  await  ephemeral_task 
271155
272156        except  aiohttp .ClientResponseError  as  e :
273157            logger .debug (
@@ -285,15 +169,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
285169
286170        finally :
287171            self ._refresh_in_progress .clear ()
288- 
289-         return  ConnectionInfo (
290-             ephemeral_cert ,
291-             metadata ["server_ca_cert" ],
292-             priv_key ,
293-             metadata ["ip_addresses" ],
294-             metadata ["database_version" ],
295-             expiration ,
296-         )
172+         return  connection_info 
297173
298174    def  _schedule_refresh (self , delay : int ) ->  asyncio .Task :
299175        """ 
0 commit comments