Skip to content

Commit 9fbe87a

Browse files
refactor: use google.auth TokenState for credentials validity (#1089)
1 parent 148fc07 commit 9fbe87a

File tree

4 files changed

+40
-17
lines changed

4 files changed

+40
-17
lines changed

google/cloud/sql/connector/client.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import aiohttp
2222
from cryptography.hazmat.backends import default_backend
2323
from cryptography.x509 import load_pem_x509_certificate
24-
import google.auth.transport.requests
2524

2625
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
2726
from google.cloud.sql.connector.version import __version__ as version
@@ -113,9 +112,6 @@ async def _get_metadata(
113112
addresses and their type and a string representing the
114113
certificate authority.
115114
"""
116-
if not self._credentials.valid:
117-
request = google.auth.transport.requests.Request()
118-
self._credentials.refresh(request)
119115

120116
headers = {
121117
"Authorization": f"Bearer {self._credentials.token}",
@@ -177,10 +173,6 @@ async def _get_ephemeral(
177173

178174
logger.debug(f"['{instance}']: Requesting ephemeral certificate")
179175

180-
if not self._credentials.valid:
181-
request = google.auth.transport.requests.Request()
182-
self._credentials.refresh(request)
183-
184176
headers = {
185177
"Authorization": f"Bearer {self._credentials.token}",
186178
}

google/cloud/sql/connector/instance.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from typing import Any, Dict, Tuple, TYPE_CHECKING
2727

2828
import aiohttp
29+
from google.auth.credentials import TokenState
30+
from google.auth.transport import requests
2931

3032
from google.cloud.sql.connector.client import CloudSQLClient
3133
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
@@ -231,6 +233,10 @@ async def _perform_refresh(self) -> ConnectionInfo:
231233

232234
logger.debug(f"['{self._instance_connection_string}']: Creating context")
233235

236+
# before making Cloud SQL Admin API calls, refresh creds
237+
if not self._client._credentials.token_state == TokenState.FRESH:
238+
self._client._credentials.refresh(requests.Request())
239+
234240
metadata_task = asyncio.create_task(
235241
self._client._get_metadata(
236242
self._project,

tests/unit/mocks.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
# file containing all mocks used for Cloud SQL Python Connector unit tests
1718

1819
import datetime
1920
import json
2021
import ssl
2122
from tempfile import TemporaryDirectory
22-
from typing import Any, Callable, Dict, Optional, Tuple
23+
from typing import Any, Callable, Dict, Literal, Optional, Tuple
2324

2425
from aiohttp import web
2526
from cryptography import x509
@@ -28,7 +29,9 @@
2829
from cryptography.hazmat.primitives import serialization
2930
from cryptography.hazmat.primitives.asymmetric import rsa
3031
from cryptography.x509.oid import NameOID
32+
from google.auth import _helpers
3133
from google.auth.credentials import Credentials
34+
from google.auth.credentials import TokenState
3235

3336
from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN
3437
from google.cloud.sql.connector.utils import generate_keys
@@ -48,7 +51,7 @@ def __class__(self) -> Credentials:
4851
# set class type to google auth Credentials
4952
return Credentials
5053

51-
def refresh(self, request: Callable) -> None:
54+
def refresh(self, _: Callable) -> None:
5255
"""Refreshes the access token."""
5356
self.token = "12345"
5457
self.expiry = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
@@ -75,13 +78,33 @@ def universe_domain(self) -> str:
7578
return self._universe_domain
7679

7780
@property
78-
def valid(self) -> bool:
79-
"""Checks the validity of the credentials.
80-
81-
This is True if the credentials have a token and the token
82-
is not expired.
81+
def token_state(
82+
self,
83+
) -> Literal[TokenState.FRESH, TokenState.STALE, TokenState.INVALID]:
8384
"""
84-
return self.token is not None and not self.expired
85+
Tracks the state of a token.
86+
FRESH: The token is valid. It is not expired or close to expired, or the token has no expiry.
87+
STALE: The token is close to expired, and should be refreshed. The token can be used normally.
88+
INVALID: The token is expired or invalid. The token cannot be used for a normal operation.
89+
"""
90+
if self.token is None:
91+
return TokenState.INVALID
92+
93+
# Credentials that can't expire are always treated as fresh.
94+
if self.expiry is None:
95+
return TokenState.FRESH
96+
97+
expired = datetime.datetime.now(datetime.timezone.utc) >= self.expiry
98+
if expired:
99+
return TokenState.INVALID
100+
101+
is_stale = datetime.datetime.now(datetime.timezone.utc) >= (
102+
self.expiry - _helpers.REFRESH_THRESHOLD
103+
)
104+
if is_stale:
105+
return TokenState.STALE
106+
107+
return TokenState.FRESH
85108

86109

87110
def generate_cert(

tests/unit/test_refresh_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import asyncio
1718
import datetime
1819

1920
from conftest import SCOPES # type: ignore
2021
import google.auth
2122
from google.auth.credentials import Credentials
23+
from google.auth.credentials import TokenState
2224
import google.oauth2.credentials
2325
from mock import Mock
2426
from mock import patch
@@ -32,7 +34,7 @@
3234
@pytest.fixture
3335
def credentials() -> Credentials:
3436
credentials = Mock(spec=Credentials)
35-
credentials.valid = True
37+
credentials.token_state = TokenState.FRESH
3638
credentials.token = "12345"
3739
return credentials
3840

0 commit comments

Comments
 (0)