Skip to content

Commit cee4759

Browse files
committed
fix(functions): Moved credential refresh to run on task payload update with freshness guard
1 parent 9331ba8 commit cee4759

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

firebase_admin/functions.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from dataclasses import dataclass
2525

2626
from google.auth.compute_engine import Credentials as ComputeEngineCredentials
27+
from google.auth.credentials import TokenState
2728
from google.auth.exceptions import RefreshError
2829
from google.auth.transport import requests as google_auth_requests
2930

@@ -103,12 +104,6 @@ def __init__(self, app: App):
103104
'GOOGLE_CLOUD_PROJECT environment variable.')
104105

105106
self._credential = app.credential.get_credential()
106-
try:
107-
# Refresh the credential to ensure all attributes (e.g. service_account_email)
108-
# are populated, preventing cold start errors.
109-
self._credential.refresh(google_auth_requests.Request())
110-
except RefreshError as err:
111-
raise ValueError(f'Initial credential refresh failed: {err}') from err
112107
self._http_client = _http_client.JsonHttpClient(credential=self._credential)
113108

114109
def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue:
@@ -294,6 +289,15 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
294289
# Get function url from task or generate from resources
295290
if not _Validators.is_non_empty_string(task.http_request['url']):
296291
task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT)
292+
293+
# Refresh the credential to ensure all attributes (e.g. service_account_email, id_token)
294+
# are populated, preventing cold start errors.
295+
if self._credential.token_state != TokenState.FRESH:
296+
try:
297+
self._credential.refresh(google_auth_requests.Request())
298+
except RefreshError as err:
299+
raise ValueError(f'Initial task payload credential refresh failed: {err}') from err
300+
297301
# If extension id is provided, it emplies that it is being run from a deployed extension.
298302
# Meaning that it's credential should be a Compute Engine Credential.
299303
if _Validators.is_non_empty_string(extension_id) and \

tests/testutils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,21 @@ def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ
116116
# pylint: disable=abstract-method
117117
class MockGoogleCredential(credentials.Credentials):
118118
"""A mock Google authentication credential."""
119+
120+
def __init__(self):
121+
super().__init__()
122+
self.token = None
123+
self._service_account_email = None
124+
self._token_state = credentials.TokenState.INVALID
125+
119126
def refresh(self, request):
120127
self.token = 'mock-token'
121-
self._service_account_email = "mock-email"
128+
self._service_account_email = 'mock-email'
129+
self._token_state = credentials.TokenState.FRESH
130+
131+
@property
132+
def token_state(self):
133+
return self._token_state
122134

123135
@property
124136
def service_account_email(self):
@@ -140,9 +152,21 @@ def get_credential(self):
140152

141153
class MockGoogleComputeEngineCredential(compute_engine.Credentials):
142154
"""A mock Compute Engine credential"""
155+
156+
def __init__(self):
157+
super().__init__()
158+
self.token = None
159+
self._service_account_email = None
160+
self._token_state = credentials.TokenState.INVALID
161+
143162
def refresh(self, request):
144163
self.token = 'mock-compute-engine-token'
145164
self._service_account_email = 'mock-gce-email'
165+
self._token_state = credentials.TokenState.FRESH
166+
167+
@property
168+
def token_state(self):
169+
return self._token_state
146170

147171
def _metric_header_for_usage(self):
148172
return 'mock-gce-cred-metric-tag'

0 commit comments

Comments
 (0)