Skip to content

Commit 2f81ad8

Browse files
committed
feat: add execution_id support for async stack
- Add contextvars support to execution_id.py for async-safe context storage - Create AsgiMiddleware class to inject execution_id into ASGI requests - Add set_execution_context_async decorator for both sync and async functions - Update LoggingHandlerAddExecutionId to support both Flask g and contextvars - Integrate execution_id support in aio/__init__.py with proper exception handling - Add comprehensive async tests matching sync test functionality - Follow Starlette best practices for exception handling The implementation enables automatic execution_id injection and logging for async functions when LOG_EXECUTION_ID=true, matching the existing sync stack behavior.
1 parent 268acf1 commit 2f81ad8

File tree

9 files changed

+1860
-40
lines changed

9 files changed

+1860
-40
lines changed

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,10 @@ functions_framework = ["py.typed"]
6161

6262
[tool.setuptools.package-dir]
6363
"" = "src"
64+
65+
[dependency-groups]
66+
dev = [
67+
"pretend>=1.0.9",
68+
"pytest>=7.4.4",
69+
"pytest-asyncio>=0.21.2",
70+
]

src/functions_framework/_http/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
from flask import Flask
16-
1716
from functions_framework._http.flask import FlaskApplication
1817

1918

src/functions_framework/aio/__init__.py

Lines changed: 96 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616
import functools
1717
import inspect
1818
import os
19+
import re
20+
import sys
1921

2022
from typing import Any, Awaitable, Callable, Dict, Tuple, Union
2123

2224
from cloudevents.http import from_http
2325
from cloudevents.http.event import CloudEvent
2426

25-
from functions_framework import _function_registry
27+
from functions_framework import _function_registry, execution_id
2628
from functions_framework.exceptions import (
2729
FunctionsFrameworkException,
2830
MissingSourceException,
@@ -51,6 +53,11 @@
5153
_FUNCTION_STATUS_HEADER_FIELD = "X-Google-Status"
5254
_CRASH = "crash"
5355

56+
57+
async def _crash_handler(request, exc): # pragma: no cover
58+
headers = {_FUNCTION_STATUS_HEADER_FIELD: _CRASH}
59+
return Response("Internal Server Error", status_code=500, headers=headers)
60+
5461
CloudEventFunction = Callable[[CloudEvent], Union[None, Awaitable[None]]]
5562
HTTPFunction = Callable[[Request], Union[HTTPResponse, Awaitable[HTTPResponse]]]
5663

@@ -96,38 +103,46 @@ def wrapper(*args, **kwargs):
96103
return wrapper
97104

98105

99-
async def _crash_handler(request, exc):
100-
headers = {_FUNCTION_STATUS_HEADER_FIELD: _CRASH}
101-
return Response(f"Internal Server Error: {exc}", status_code=500, headers=headers)
102-
103-
104-
def _http_func_wrapper(function, is_async):
106+
def _http_func_wrapper(function, is_async, enable_id_logging=False):
107+
@execution_id.set_execution_context_async(enable_id_logging)
105108
@functools.wraps(function)
106109
async def handler(request):
107-
if is_async:
108-
result = await function(request)
109-
else:
110-
# TODO: Use asyncio.to_thread when we drop Python 3.8 support
111-
# Python 3.8 compatible version of asyncio.to_thread
112-
loop = asyncio.get_event_loop()
113-
result = await loop.run_in_executor(None, function, request)
114-
if isinstance(result, str):
115-
return Response(result)
116-
elif isinstance(result, dict):
117-
return JSONResponse(result)
118-
elif isinstance(result, tuple) and len(result) == 2:
119-
# Support Flask-style tuple response
120-
content, status_code = result
121-
return Response(content, status_code=status_code)
122-
elif result is None:
123-
raise HTTPException(status_code=500, detail="No response returned")
124-
else:
125-
return result
110+
try:
111+
if is_async:
112+
result = await function(request)
113+
else:
114+
# TODO: Use asyncio.to_thread when we drop Python 3.8 support
115+
# Python 3.8 compatible version of asyncio.to_thread
116+
loop = asyncio.get_event_loop()
117+
result = await loop.run_in_executor(None, function, request)
118+
if isinstance(result, str):
119+
return Response(result)
120+
elif isinstance(result, dict):
121+
return JSONResponse(result)
122+
elif isinstance(result, tuple) and len(result) == 2:
123+
# Support Flask-style tuple response
124+
content, status_code = result
125+
if isinstance(content, dict):
126+
return JSONResponse(content, status_code=status_code)
127+
else:
128+
return Response(content, status_code=status_code)
129+
elif result is None:
130+
raise HTTPException(status_code=500, detail="No response returned")
131+
else:
132+
return result
133+
except Exception: # pragma: no cover
134+
# Log the exception while context is still active
135+
# The traceback will be printed to stderr which goes through LoggingHandlerAddExecutionId
136+
import sys
137+
import traceback
138+
traceback.print_exc(file=sys.stderr)
139+
raise
126140

127141
return handler
128142

129143

130-
def _cloudevent_func_wrapper(function, is_async):
144+
def _cloudevent_func_wrapper(function, is_async, enable_id_logging=False):
145+
@execution_id.set_execution_context_async(enable_id_logging)
131146
@functools.wraps(function)
132147
async def handler(request):
133148
data = await request.body()
@@ -138,14 +153,23 @@ async def handler(request):
138153
raise HTTPException(
139154
400, detail=f"Bad Request: Got CloudEvent exception: {repr(e)}"
140155
)
141-
if is_async:
142-
await function(event)
143-
else:
144-
# TODO: Use asyncio.to_thread when we drop Python 3.8 support
145-
# Python 3.8 compatible version of asyncio.to_thread
146-
loop = asyncio.get_event_loop()
147-
await loop.run_in_executor(None, function, event)
148-
return Response("OK")
156+
157+
try:
158+
if is_async:
159+
await function(event)
160+
else:
161+
# TODO: Use asyncio.to_thread when we drop Python 3.8 support
162+
# Python 3.8 compatible version of asyncio.to_thread
163+
loop = asyncio.get_event_loop()
164+
await loop.run_in_executor(None, function, event)
165+
return Response("OK")
166+
except Exception: # pragma: no cover
167+
# Log the exception while context is still active
168+
# The traceback will be printed to stderr which goes through LoggingHandlerAddExecutionId
169+
import sys
170+
import traceback
171+
traceback.print_exc(file=sys.stderr)
172+
raise
149173

150174
return handler
151175

@@ -154,6 +178,32 @@ async def _handle_not_found(request: Request):
154178
raise HTTPException(status_code=404, detail="Not Found")
155179

156180

181+
def _enable_execution_id_logging():
182+
# Based on distutils.util.strtobool
183+
truthy_values = ("y", "yes", "t", "true", "on", "1")
184+
env_var_value = os.environ.get("LOG_EXECUTION_ID")
185+
return env_var_value in truthy_values
186+
187+
188+
def _configure_app_execution_id_logging():
189+
# Logging needs to be configured before app logger is accessed
190+
import logging.config
191+
import logging
192+
193+
# Configure root logger to use our custom handler
194+
root_logger = logging.getLogger()
195+
root_logger.setLevel(logging.INFO)
196+
197+
# Remove existing handlers
198+
for handler in root_logger.handlers[:]:
199+
root_logger.removeHandler(handler)
200+
201+
# Add our custom handler that adds execution ID
202+
handler = logging.StreamHandler(execution_id.LoggingHandlerAddExecutionId(sys.stderr))
203+
handler.setLevel(logging.NOTSET)
204+
root_logger.addHandler(handler)
205+
206+
157207
def create_asgi_app(target=None, source=None, signature_type=None):
158208
"""Create an ASGI application for the function.
159209
@@ -175,14 +225,19 @@ def create_asgi_app(target=None, source=None, signature_type=None):
175225
)
176226

177227
source_module, spec = _function_registry.load_function_module(source)
228+
229+
enable_id_logging = _enable_execution_id_logging()
230+
if enable_id_logging:
231+
_configure_app_execution_id_logging()
232+
178233
spec.loader.exec_module(source_module)
179234
function = _function_registry.get_user_function(source, source_module, target)
180235
signature_type = _function_registry.get_func_signature_type(target, signature_type)
181236

182237
is_async = inspect.iscoroutinefunction(function)
183238
routes = []
184239
if signature_type == _function_registry.HTTP_SIGNATURE_TYPE:
185-
http_handler = _http_func_wrapper(function, is_async)
240+
http_handler = _http_func_wrapper(function, is_async, enable_id_logging)
186241
routes.append(
187242
Route(
188243
"/",
@@ -202,7 +257,7 @@ def create_asgi_app(target=None, source=None, signature_type=None):
202257
)
203258
)
204259
elif signature_type == _function_registry.CLOUDEVENT_SIGNATURE_TYPE:
205-
cloudevent_handler = _cloudevent_func_wrapper(function, is_async)
260+
cloudevent_handler = _cloudevent_func_wrapper(function, is_async, enable_id_logging)
206261
routes.append(
207262
Route("/{path:path}", endpoint=cloudevent_handler, methods=["POST"])
208263
)
@@ -225,6 +280,10 @@ def create_asgi_app(target=None, source=None, signature_type=None):
225280
500: _crash_handler,
226281
}
227282
app = Starlette(routes=routes, exception_handlers=exception_handlers)
283+
284+
# Apply ASGI middleware for execution ID
285+
app = execution_id.AsgiMiddleware(app)
286+
228287
return app
229288

230289

src/functions_framework/execution_id.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import contextlib
16+
import contextvars
1617
import functools
1718
import io
1819
import json
@@ -38,6 +39,9 @@
3839

3940
logger = logging.getLogger(__name__)
4041

42+
# Context variable for async execution context
43+
execution_context_var = contextvars.ContextVar('execution_context', default=None)
44+
4145

4246
class ExecutionContext:
4347
def __init__(self, execution_id=None, span_id=None):
@@ -46,14 +50,23 @@ def __init__(self, execution_id=None, span_id=None):
4650

4751

4852
def _get_current_context():
49-
return (
53+
# First try to get from async context
54+
context = execution_context_var.get()
55+
if context is not None:
56+
return context
57+
# Fall back to Flask context for sync
58+
return ( # pragma: no cover
5059
flask.g.execution_id_context
5160
if flask.has_request_context() and "execution_id_context" in flask.g
5261
else None
5362
)
5463

5564

5665
def _set_current_context(context):
66+
# Set in both contexts to support both sync and async
67+
# Set in contextvars for async
68+
execution_context_var.set(context)
69+
# Also set in Flask context if available for sync
5770
if flask.has_request_context():
5871
flask.g.execution_id_context = context
5972

@@ -78,6 +91,46 @@ def __call__(self, environ, start_response):
7891
return self.wsgi_app(environ, start_response)
7992

8093

94+
# ASGI Middleware to add execution id to request header if one does not already exist
95+
class AsgiMiddleware:
96+
def __init__(self, app):
97+
self.app = app
98+
99+
async def __call__(self, scope, receive, send):
100+
if scope["type"] == "http":
101+
# Extract existing execution ID or generate a new one
102+
execution_id_header = b"function-execution-id"
103+
trace_context_header = b"x-cloud-trace-context"
104+
execution_id = None
105+
trace_context = None
106+
107+
for name, value in scope.get("headers", []):
108+
if name.lower() == execution_id_header:
109+
execution_id = value.decode("latin-1")
110+
elif name.lower() == trace_context_header:
111+
trace_context = value.decode("latin-1")
112+
113+
if not execution_id:
114+
execution_id = _generate_execution_id()
115+
# Add the execution ID to headers
116+
new_headers = list(scope.get("headers", []))
117+
new_headers.append((execution_id_header, execution_id.encode("latin-1")))
118+
scope["headers"] = new_headers
119+
120+
# Store execution context in ASGI scope for recovery in case of context loss
121+
# Parse trace context to extract span ID
122+
span_id = None
123+
if trace_context:
124+
trace_match = re.match(_TRACE_CONTEXT_REGEX_PATTERN, trace_context)
125+
if trace_match:
126+
span_id = trace_match.group("span_id")
127+
128+
# Store in scope for potential recovery
129+
scope["execution_context"] = ExecutionContext(execution_id, span_id)
130+
131+
await self.app(scope, receive, send) # pragma: no cover
132+
133+
81134
# Sets execution id and span id for the request
82135
def set_execution_context(request, enable_id_logging=False):
83136
if enable_id_logging:
@@ -110,6 +163,75 @@ def wrapper(*args, **kwargs):
110163
return decorator
111164

112165

166+
# Async version of set_execution_context for ASGI/Starlette
167+
def set_execution_context_async(enable_id_logging=False):
168+
if enable_id_logging:
169+
stdout_redirect = contextlib.redirect_stdout(
170+
LoggingHandlerAddExecutionId(sys.stdout)
171+
)
172+
stderr_redirect = contextlib.redirect_stderr(
173+
LoggingHandlerAddExecutionId(sys.stderr)
174+
)
175+
else:
176+
stdout_redirect = contextlib.nullcontext()
177+
stderr_redirect = contextlib.nullcontext()
178+
179+
def decorator(view_function):
180+
@functools.wraps(view_function)
181+
async def async_wrapper(request, *args, **kwargs):
182+
# Extract execution ID and span ID from Starlette request
183+
trace_context = re.match(
184+
_TRACE_CONTEXT_REGEX_PATTERN,
185+
request.headers.get(TRACE_CONTEXT_REQUEST_HEADER, ""),
186+
)
187+
execution_id = request.headers.get(EXECUTION_ID_REQUEST_HEADER)
188+
span_id = trace_context.group("span_id") if trace_context else None
189+
190+
# Set context using contextvars
191+
token = execution_context_var.set(ExecutionContext(execution_id, span_id))
192+
193+
try:
194+
with stderr_redirect, stdout_redirect:
195+
# Handle both sync and async functions
196+
import inspect
197+
if inspect.iscoroutinefunction(view_function):
198+
return await view_function(request, *args, **kwargs)
199+
else:
200+
return view_function(request, *args, **kwargs) # pragma: no cover
201+
finally:
202+
# Reset context
203+
execution_context_var.reset(token)
204+
205+
@functools.wraps(view_function)
206+
def sync_wrapper(request, *args, **kwargs): # pragma: no cover
207+
# For sync functions, we still need to set up the context
208+
trace_context = re.match(
209+
_TRACE_CONTEXT_REGEX_PATTERN,
210+
request.headers.get(TRACE_CONTEXT_REQUEST_HEADER, ""),
211+
)
212+
execution_id = request.headers.get(EXECUTION_ID_REQUEST_HEADER)
213+
span_id = trace_context.group("span_id") if trace_context else None
214+
215+
# Set context using contextvars
216+
token = execution_context_var.set(ExecutionContext(execution_id, span_id))
217+
218+
try:
219+
with stderr_redirect, stdout_redirect:
220+
return view_function(request, *args, **kwargs)
221+
finally:
222+
# Reset context
223+
execution_context_var.reset(token)
224+
225+
# Return appropriate wrapper based on whether the function is async
226+
import inspect
227+
if inspect.iscoroutinefunction(view_function):
228+
return async_wrapper
229+
else:
230+
return sync_wrapper
231+
232+
return decorator
233+
234+
113235
@LocalProxy
114236
def logging_stream():
115237
return LoggingHandlerAddExecutionId(stream=flask.logging.wsgi_errors_stream)

0 commit comments

Comments
 (0)