Skip to content

Commit 86b2267

Browse files
committed
feat: auto-detect ASGI mode for @aio decorated functions
Functions decorated with @aio.http or @aio.cloud_event now automatically run in ASGI mode without requiring the --asgi flag. This improves the developer experience by removing the need to remember to pass the flag when using async decorators.
1 parent 58deaf1 commit 86b2267

File tree

5 files changed

+99
-1
lines changed

5 files changed

+99
-1
lines changed

src/functions_framework/_cli.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import click
1818

19-
from functions_framework import create_app
19+
from functions_framework import create_app, _function_registry
2020
from functions_framework._http import create_server
2121

2222

@@ -39,6 +39,9 @@
3939
help="Use ASGI server for function execution",
4040
)
4141
def _cli(target, source, signature_type, host, port, debug, asgi):
42+
if not asgi and target in _function_registry.ASGI_FUNCTIONS:
43+
asgi = True
44+
4245
if asgi: # pragma: no cover
4346
from functions_framework.aio import create_asgi_app
4447

src/functions_framework/_function_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
# Keys are the user function name, values are the type of the function input
4141
INPUT_TYPE_MAP = {}
4242

43+
# ASGI_FUNCTIONS stores function names that require ASGI mode.
44+
# Functions decorated with @aio.http or @aio.cloud_event are added here.
45+
ASGI_FUNCTIONS = set()
46+
4347

4448
def get_user_function(source, source_module, target):
4549
"""Returns user function, raises exception for invalid function."""

src/functions_framework/aio/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def cloud_event(func: CloudEventFunction) -> CloudEventFunction:
6969
_function_registry.REGISTRY_MAP[func.__name__] = (
7070
_function_registry.CLOUDEVENT_SIGNATURE_TYPE
7171
)
72+
_function_registry.ASGI_FUNCTIONS.add(func.__name__)
7273
if inspect.iscoroutinefunction(func):
7374

7475
@functools.wraps(func)
@@ -89,6 +90,7 @@ def http(func: HTTPFunction) -> HTTPFunction:
8990
_function_registry.REGISTRY_MAP[func.__name__] = (
9091
_function_registry.HTTP_SIGNATURE_TYPE
9192
)
93+
_function_registry.ASGI_FUNCTIONS.add(func.__name__)
9294

9395
if inspect.iscoroutinefunction(func):
9496

tests/test_cli.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from click.testing import CliRunner
2121

2222
import functions_framework
23+
import functions_framework._function_registry as _function_registry
24+
import functions_framework.aio
2325

2426
from functions_framework._cli import _cli
2527

@@ -124,3 +126,49 @@ def test_asgi_cli(monkeypatch):
124126
assert result.exit_code == 0
125127
assert create_asgi_app.calls == [pretend.call("foo", None, "http")]
126128
assert asgi_server.run.calls == [pretend.call("0.0.0.0", 8080)]
129+
130+
131+
def test_auto_asgi_for_aio_decorated_functions(monkeypatch):
132+
original_asgi_functions = _function_registry.ASGI_FUNCTIONS.copy()
133+
_function_registry.ASGI_FUNCTIONS.clear()
134+
_function_registry.ASGI_FUNCTIONS.add("my_aio_func")
135+
136+
asgi_app = pretend.stub()
137+
create_asgi_app = pretend.call_recorder(lambda *a, **k: asgi_app)
138+
aio_module = pretend.stub(create_asgi_app=create_asgi_app)
139+
monkeypatch.setitem(sys.modules, "functions_framework.aio", aio_module)
140+
141+
asgi_server = pretend.stub(run=pretend.call_recorder(lambda host, port: None))
142+
create_server = pretend.call_recorder(lambda app, debug: asgi_server)
143+
monkeypatch.setattr(functions_framework._cli, "create_server", create_server)
144+
145+
runner = CliRunner()
146+
result = runner.invoke(_cli, ["--target", "my_aio_func"])
147+
148+
assert create_asgi_app.calls == [pretend.call("my_aio_func", None, "http")]
149+
assert asgi_server.run.calls == [pretend.call("0.0.0.0", 8080)]
150+
151+
_function_registry.ASGI_FUNCTIONS.clear()
152+
_function_registry.ASGI_FUNCTIONS.update(original_asgi_functions)
153+
154+
155+
def test_no_auto_asgi_for_regular_functions(monkeypatch):
156+
original_asgi_functions = _function_registry.ASGI_FUNCTIONS.copy()
157+
_function_registry.ASGI_FUNCTIONS.clear()
158+
159+
app = pretend.stub()
160+
create_app = pretend.call_recorder(lambda *a, **k: app)
161+
monkeypatch.setattr(functions_framework._cli, "create_app", create_app)
162+
163+
flask_server = pretend.stub(run=pretend.call_recorder(lambda host, port: None))
164+
create_server = pretend.call_recorder(lambda app, debug: flask_server)
165+
monkeypatch.setattr(functions_framework._cli, "create_server", create_server)
166+
167+
runner = CliRunner()
168+
result = runner.invoke(_cli, ["--target", "regular_func"])
169+
170+
assert create_app.calls == [pretend.call("regular_func", None, "http")]
171+
assert flask_server.run.calls == [pretend.call("0.0.0.0", 8080)]
172+
173+
_function_registry.ASGI_FUNCTIONS.clear()
174+
_function_registry.ASGI_FUNCTIONS.update(original_asgi_functions)

tests/test_decorator_functions.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from cloudevents import conversion as ce_conversion
2020
from cloudevents.http import CloudEvent
21+
import functions_framework._function_registry as registry
2122

2223
# Conditional import for Starlette
2324
if sys.version_info >= (3, 8):
@@ -128,3 +129,43 @@ def test_aio_http_dict_response():
128129
resp = client.post("/")
129130
assert resp.status_code == 200
130131
assert resp.json() == {"message": "hello", "count": 42, "success": True}
132+
133+
134+
def test_aio_decorators_register_asgi_functions():
135+
"""Test that @aio decorators add function names to ASGI_FUNCTIONS registry."""
136+
original_registry_map = registry.REGISTRY_MAP.copy()
137+
original_asgi_functions = registry.ASGI_FUNCTIONS.copy()
138+
registry.REGISTRY_MAP.clear()
139+
registry.ASGI_FUNCTIONS.clear()
140+
141+
from functions_framework.aio import http, cloud_event
142+
143+
@http
144+
async def test_http_func(request):
145+
return "test"
146+
147+
@cloud_event
148+
async def test_cloud_event_func(event):
149+
pass
150+
151+
assert "test_http_func" in registry.ASGI_FUNCTIONS
152+
assert "test_cloud_event_func" in registry.ASGI_FUNCTIONS
153+
154+
assert registry.REGISTRY_MAP["test_http_func"] == "http"
155+
assert registry.REGISTRY_MAP["test_cloud_event_func"] == "cloudevent"
156+
157+
@http
158+
def test_http_sync(request):
159+
return "sync"
160+
161+
@cloud_event
162+
def test_cloud_event_sync(event):
163+
pass
164+
165+
assert "test_http_sync" in registry.ASGI_FUNCTIONS
166+
assert "test_cloud_event_sync" in registry.ASGI_FUNCTIONS
167+
168+
registry.REGISTRY_MAP.clear()
169+
registry.REGISTRY_MAP.update(original_registry_map)
170+
registry.ASGI_FUNCTIONS.clear()
171+
registry.ASGI_FUNCTIONS.update(original_asgi_functions)

0 commit comments

Comments
 (0)