@@ -80,6 +80,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
8080 super (ConnectionSSLContext , self ).__init__ (* args , ** kwargs )
8181
8282
83+ class TLSVersionError (Exception ):
84+ """
85+ Raised when the required TLS protocol version is not supported.
86+ """
87+
88+ def __init__ (self , * args : Any ) -> None :
89+ super (TLSVersionError , self ).__init__ (self , * args )
90+
91+
8392class CloudSQLConnectionError (Exception ):
8493 """
8594 Raised when the provided connection string is not formatted
@@ -111,8 +120,16 @@ def __init__(
111120 private_key : bytes ,
112121 server_ca_cert : str ,
113122 expiration : datetime .datetime ,
123+ enable_iam_auth : bool ,
114124 ) -> None :
115125 self .ip_addrs = ip_addrs
126+
127+ if enable_iam_auth and not ssl .HAS_TLSv1_3 : # type: ignore
128+ raise TLSVersionError (
129+ "Your current version of OpenSSL does not support TLSv1.3, "
130+ "which is required to use IAM Authentication."
131+ )
132+
116133 self .context = ConnectionSSLContext ()
117134 self .expiration = expiration
118135
@@ -293,18 +310,20 @@ async def _get_instance_data(self) -> InstanceMetadata:
293310 expiration = datetime .datetime .strptime (
294311 x509 .get_notAfter ().decode ("ascii" ), "%Y%m%d%H%M%SZ"
295312 )
296- if self ._credentials is not None :
297- token_expiration : datetime .datetime = self ._credentials .expiry
298313
299- if expiration > token_expiration :
300- expiration = token_expiration
314+ if self ._enable_iam_auth :
315+ if self ._credentials is not None :
316+ token_expiration : datetime .datetime = self ._credentials .expiry
317+ if expiration > token_expiration :
318+ expiration = token_expiration
301319
302320 return InstanceMetadata (
303321 ephemeral_cert ,
304322 metadata ["ip_addresses" ],
305323 priv_key ,
306324 metadata ["server_ca_cert" ],
307325 expiration ,
326+ self ._enable_iam_auth ,
308327 )
309328
310329 def _auth_init (self ) -> None :
0 commit comments