Skip to content

Commit 9331ba8

Browse files
committed
fix(functions): Move credential refresh to functions service init
1 parent 6c8f2f2 commit 9331ba8

File tree

3 files changed

+68
-31
lines changed

3 files changed

+68
-31
lines changed

firebase_admin/functions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@
2424
from dataclasses import dataclass
2525

2626
from google.auth.compute_engine import Credentials as ComputeEngineCredentials
27+
from google.auth.exceptions import RefreshError
2728
from google.auth.transport import requests as google_auth_requests
29+
2830
import requests
2931
import firebase_admin
3032
from firebase_admin import App
@@ -101,6 +103,12 @@ def __init__(self, app: App):
101103
'GOOGLE_CLOUD_PROJECT environment variable.')
102104

103105
self._credential = app.credential.get_credential()
106+
try:
107+
# Refresh the credential to ensure all attributes (e.g. service_account_email)
108+
# are populated, preventing cold start errors.
109+
self._credential.refresh(google_auth_requests.Request())
110+
except RefreshError as err:
111+
raise ValueError(f'Initial credential refresh failed: {err}') from err
104112
self._http_client = _http_client.JsonHttpClient(credential=self._credential)
105113

106114
def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue:
@@ -290,7 +298,6 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
290298
# Meaning that it's credential should be a Compute Engine Credential.
291299
if _Validators.is_non_empty_string(extension_id) and \
292300
isinstance(self._credential, ComputeEngineCredentials):
293-
self._credential.refresh(google_auth_requests.Request())
294301
id_token = self._credential.token
295302
task.http_request['headers'] = \
296303
{**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'}

tests/test_functions.py

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from datetime import datetime, timedelta, timezone
1818
import json
1919
import time
20-
from unittest import mock
2120
import pytest
2221

2322
import firebase_admin
@@ -125,6 +124,10 @@ def test_task_enqueue(self):
125124
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
126125
assert task_id == 'test-task-id'
127126

127+
task = json.loads(recorder[0].body.decode())['task']
128+
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
129+
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
130+
128131
def test_task_enqueue_with_extension(self):
129132
resource_name = (
130133
'projects/test-project/locations/us-central1/queues/'
@@ -143,46 +146,68 @@ def test_task_enqueue_with_extension(self):
143146
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
144147
assert task_id == 'test-task-id'
145148

146-
def test_task_delete(self):
147-
_, recorder = self._instrument_functions_service()
148-
queue = functions.task_queue('test-function-name')
149-
queue.delete('test-task-id')
149+
task = json.loads(recorder[0].body.decode())['task']
150+
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
151+
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
152+
153+
def test_task_enqueue_compute_engine(self):
154+
app = firebase_admin.initialize_app(
155+
testutils.MockComputeEngineCredential(),
156+
options={'projectId': 'test-project'},
157+
name='test-project-gce')
158+
_, recorder = self._instrument_functions_service(app)
159+
queue = functions.task_queue('test-function-name', app=app)
160+
task_id = queue.enqueue(_DEFAULT_DATA)
150161
assert len(recorder) == 1
151-
assert recorder[0].method == 'DELETE'
152-
assert recorder[0].url == _DEFAULT_TASK_URL
153-
expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag'
162+
assert recorder[0].method == 'POST'
163+
assert recorder[0].url == _DEFAULT_REQUEST_URL
164+
assert recorder[0].headers['Content-Type'] == 'application/json'
165+
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
166+
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
154167
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
168+
assert task_id == 'test-task-id'
155169

156-
@mock.patch('firebase_admin.functions.isinstance')
157-
def test_task_enqueue_with_extension_refreshes_credential(self, mock_isinstance):
158-
# Force the code to take the ComputeEngineCredentials path
159-
mock_isinstance.return_value = True
170+
task = json.loads(recorder[0].body.decode())['task']
171+
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'}
172+
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
160173

161-
# Create a custom response with the extension ID in the resource name
174+
def test_task_enqueue_with_extension_compute_engine(self):
162175
resource_name = (
163176
'projects/test-project/locations/us-central1/queues/'
164177
'ext-test-extension-id-test-function-name/tasks'
165178
)
166179
extension_response = json.dumps({'name': resource_name + '/test-task-id'})
180+
app = firebase_admin.initialize_app(
181+
testutils.MockComputeEngineCredential(),
182+
options={'projectId': 'test-project'},
183+
name='test-project-gce-extensions')
184+
_, recorder = self._instrument_functions_service(app, payload=extension_response)
185+
queue = functions.task_queue('test-function-name', 'test-extension-id', app)
186+
task_id = queue.enqueue(_DEFAULT_DATA)
187+
assert len(recorder) == 1
188+
assert recorder[0].method == 'POST'
189+
assert recorder[0].url == _CLOUD_TASKS_URL + resource_name
190+
assert recorder[0].headers['Content-Type'] == 'application/json'
191+
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
192+
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
193+
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
194+
assert task_id == 'test-task-id'
167195

168-
# Instrument the service and get the underlying credential mock
169-
functions_service, recorder = self._instrument_functions_service(payload=extension_response)
170-
mock_credential = functions_service._credential
171-
mock_credential.token = 'mock-id-token'
172-
mock_credential.refresh = mock.MagicMock()
173-
174-
# Create a TaskQueue with an extension ID
175-
queue = functions_service.task_queue('test-function-name', 'test-extension-id')
176-
177-
# Enqueue a task
178-
queue.enqueue(_DEFAULT_DATA)
179-
180-
# Assert that the credential was refreshed
181-
mock_credential.refresh.assert_called_once()
196+
task = json.loads(recorder[0].body.decode())['task']
197+
assert 'oidc_token' not in task['http_request']
198+
assert task['http_request']['headers'] == {
199+
'Content-Type': 'application/json',
200+
'Authorization': 'Bearer mock-compute-engine-token'}
182201

183-
# Assert that the correct token was used in the header
202+
def test_task_delete(self):
203+
_, recorder = self._instrument_functions_service()
204+
queue = functions.task_queue('test-function-name')
205+
queue.delete('test-task-id')
184206
assert len(recorder) == 1
185-
assert recorder[0].headers['Authorization'] == 'Bearer mock-id-token'
207+
assert recorder[0].method == 'DELETE'
208+
assert recorder[0].url == _DEFAULT_TASK_URL
209+
expected_metrics_header = _utils.get_metrics_header() + ' mock-cred-metric-tag'
210+
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
186211

187212
class TestTaskQueueOptions:
188213

tests/testutils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,11 @@ class MockGoogleCredential(credentials.Credentials):
118118
"""A mock Google authentication credential."""
119119
def refresh(self, request):
120120
self.token = 'mock-token'
121+
self._service_account_email = "mock-email"
121122

122123
@property
123124
def service_account_email(self):
124-
return 'mock-email'
125+
return self._service_account_email
125126

126127
# Simulate x-goog-api-client modification in credential refresh
127128
def _metric_header_for_usage(self):
@@ -141,6 +142,10 @@ class MockGoogleComputeEngineCredential(compute_engine.Credentials):
141142
"""A mock Compute Engine credential"""
142143
def refresh(self, request):
143144
self.token = 'mock-compute-engine-token'
145+
self._service_account_email = 'mock-gce-email'
146+
147+
def _metric_header_for_usage(self):
148+
return 'mock-gce-cred-metric-tag'
144149

145150
class MockComputeEngineCredential(firebase_admin.credentials.Base):
146151
"""A mock Firebase credential implementation."""

0 commit comments

Comments
 (0)