Skip to content

Commit 0fb269c

Browse files
committed
Add support for ASGI lifespan
Signed-off-by: Anuraag Agrawal <[email protected]>
1 parent 6be07d9 commit 0fb269c

File tree

6 files changed

+287
-47
lines changed

6 files changed

+287
-47
lines changed

conformance/test/gen/connectrpc/conformance/v1/service_connect.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT!
22
# source: connectrpc/conformance/v1/service.proto
33

4-
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
4+
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping
55
from typing import Protocol
66

77
from connectrpc.client import ConnectClient, ConnectClientSync
@@ -72,16 +72,17 @@ async def idempotent_unary(
7272
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")
7373

7474

75-
class ConformanceServiceASGIApplication(ConnectASGIApplication):
75+
class ConformanceServiceASGIApplication(ConnectASGIApplication[ConformanceService]):
7676
def __init__(
7777
self,
78-
service: ConformanceService,
78+
service: ConformanceService | AsyncGenerator[ConformanceService, None],
7979
*,
8080
interceptors: Iterable[Interceptor] = (),
8181
read_max_bytes: int | None = None,
8282
) -> None:
8383
super().__init__(
84-
endpoints={
84+
service=service,
85+
endpoints=lambda svc: {
8586
"/connectrpc.conformance.v1.ConformanceService/Unary": Endpoint.unary(
8687
method=MethodInfo(
8788
name="Unary",
@@ -90,7 +91,7 @@ def __init__(
9091
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnaryResponse,
9192
idempotency_level=IdempotencyLevel.UNKNOWN,
9293
),
93-
function=service.unary,
94+
function=svc.unary,
9495
),
9596
"/connectrpc.conformance.v1.ConformanceService/ServerStream": Endpoint.server_stream(
9697
method=MethodInfo(
@@ -100,7 +101,7 @@ def __init__(
100101
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ServerStreamResponse,
101102
idempotency_level=IdempotencyLevel.UNKNOWN,
102103
),
103-
function=service.server_stream,
104+
function=svc.server_stream,
104105
),
105106
"/connectrpc.conformance.v1.ConformanceService/ClientStream": Endpoint.client_stream(
106107
method=MethodInfo(
@@ -110,7 +111,7 @@ def __init__(
110111
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.ClientStreamResponse,
111112
idempotency_level=IdempotencyLevel.UNKNOWN,
112113
),
113-
function=service.client_stream,
114+
function=svc.client_stream,
114115
),
115116
"/connectrpc.conformance.v1.ConformanceService/BidiStream": Endpoint.bidi_stream(
116117
method=MethodInfo(
@@ -120,7 +121,7 @@ def __init__(
120121
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.BidiStreamResponse,
121122
idempotency_level=IdempotencyLevel.UNKNOWN,
122123
),
123-
function=service.bidi_stream,
124+
function=svc.bidi_stream,
124125
),
125126
"/connectrpc.conformance.v1.ConformanceService/Unimplemented": Endpoint.unary(
126127
method=MethodInfo(
@@ -130,7 +131,7 @@ def __init__(
130131
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.UnimplementedResponse,
131132
idempotency_level=IdempotencyLevel.UNKNOWN,
132133
),
133-
function=service.unimplemented,
134+
function=svc.unimplemented,
134135
),
135136
"/connectrpc.conformance.v1.ConformanceService/IdempotentUnary": Endpoint.unary(
136137
method=MethodInfo(
@@ -140,7 +141,7 @@ def __init__(
140141
output=connectrpc_dot_conformance_dot_v1_dot_service__pb2.IdempotentUnaryResponse,
141142
idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS,
142143
),
143-
function=service.idempotent_unary,
144+
function=svc.idempotent_unary,
144145
),
145146
},
146147
interceptors=interceptors,

example/example/eliza_connect.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT!
22
# source: example/eliza.proto
33

4-
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
4+
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator, Mapping
55
from typing import Protocol
66

77
from connectrpc.client import ConnectClient, ConnectClientSync
@@ -39,16 +39,17 @@ def introduce(
3939
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")
4040

4141

42-
class ElizaServiceASGIApplication(ConnectASGIApplication):
42+
class ElizaServiceASGIApplication(ConnectASGIApplication[ElizaService]):
4343
def __init__(
4444
self,
45-
service: ElizaService,
45+
service: ElizaService | AsyncGenerator[ElizaService, None],
4646
*,
4747
interceptors: Iterable[Interceptor] = (),
4848
read_max_bytes: int | None = None,
4949
) -> None:
5050
super().__init__(
51-
endpoints={
51+
service=service,
52+
endpoints=lambda svc: {
5253
"/connectrpc.eliza.v1.ElizaService/Say": Endpoint.unary(
5354
method=MethodInfo(
5455
name="Say",
@@ -57,7 +58,7 @@ def __init__(
5758
output=example_dot_eliza__pb2.SayResponse,
5859
idempotency_level=IdempotencyLevel.NO_SIDE_EFFECTS,
5960
),
60-
function=service.say,
61+
function=svc.say,
6162
),
6263
"/connectrpc.eliza.v1.ElizaService/Converse": Endpoint.bidi_stream(
6364
method=MethodInfo(
@@ -67,7 +68,7 @@ def __init__(
6768
output=example_dot_eliza__pb2.ConverseResponse,
6869
idempotency_level=IdempotencyLevel.UNKNOWN,
6970
),
70-
function=service.converse,
71+
function=svc.converse,
7172
),
7273
"/connectrpc.eliza.v1.ElizaService/Introduce": Endpoint.server_stream(
7374
method=MethodInfo(
@@ -77,7 +78,7 @@ def __init__(
7778
output=example_dot_eliza__pb2.IntroduceResponse,
7879
idempotency_level=IdempotencyLevel.UNKNOWN,
7980
),
80-
function=service.introduce,
81+
function=svc.introduce,
8182
),
8283
},
8384
interceptors=interceptors,

protoc-gen-connect-python/generator/template.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ var ConnectTemplate = template.Must(template.New("ConnectTemplate").Parse(`# -*-
4444
# Generated by https://github.com/connectrpc/connect-python. DO NOT EDIT!
4545
# source: {{.FileName}}
4646
{{if .Services}}
47-
from collections.abc import AsyncIterator, Iterable, Iterator, Mapping
47+
from collections.abc import AsyncIterator, AsyncGenerator, Iterable, Iterator, Mapping
4848
from typing import Protocol
4949
5050
from connectrpc.client import ConnectClient, ConnectClientSync
@@ -67,10 +67,11 @@ class {{.Name}}(Protocol):{{- range .Methods }}
6767
raise ConnectError(Code.UNIMPLEMENTED, "Not implemented")
6868
{{ end }}
6969
70-
class {{.Name}}ASGIApplication(ConnectASGIApplication):
71-
def __init__(self, service: {{.Name}}, *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None:
70+
class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
71+
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}, None], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None:
7272
super().__init__(
73-
endpoints={ {{- range .Methods }}
73+
service=service,
74+
endpoints=lambda svc: { {{- range .Methods }}
7475
"/{{.ServiceName}}/{{.Name}}": Endpoint.{{.EndpointType}}(
7576
method=MethodInfo(
7677
name="{{.Name}}",
@@ -79,7 +80,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication):
7980
output={{.OutputType}},
8081
idempotency_level=IdempotencyLevel.{{.IdempotencyLevel}},
8182
),
82-
function=service.{{.PythonName}},
83+
function=svc.{{.PythonName}},
8384
),{{- end }}
8485
},
8586
interceptors=interceptors,

src/connectrpc/_server_async.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
import base64
22
import functools
3-
from abc import ABC, abstractmethod
3+
import inspect
4+
from abc import abstractmethod
45
from asyncio import CancelledError, sleep
5-
from collections.abc import AsyncIterator, Iterable, Mapping, Sequence
6+
from collections.abc import (
7+
AsyncGenerator,
8+
AsyncIterator,
9+
Callable,
10+
Iterable,
11+
Mapping,
12+
Sequence,
13+
)
614
from dataclasses import replace
715
from http import HTTPStatus
8-
from typing import TYPE_CHECKING, TypeVar
16+
from typing import TYPE_CHECKING, Generic, TypeVar, cast
917
from urllib.parse import parse_qs
1018

1119
from . import _compression, _server_shared
@@ -48,6 +56,7 @@
4856
Scope = "asgiref.typing.Scope"
4957

5058

59+
_SVC = TypeVar("_SVC")
5160
_REQ = TypeVar("_REQ")
5261
_RES = TypeVar("_RES")
5362

@@ -64,45 +73,99 @@
6473
)
6574

6675

67-
class ConnectASGIApplication(ABC):
76+
class ConnectASGIApplication(Generic[_SVC]):
6877
"""An ASGI application for the Connect protocol."""
6978

79+
_resolved_endpoints: Mapping[str, Endpoint] | None
80+
7081
@property
7182
@abstractmethod
7283
def path(self) -> str: ...
7384

7485
def __init__(
7586
self,
7687
*,
77-
endpoints: Mapping[str, Endpoint],
88+
service: _SVC | AsyncGenerator[_SVC, None],
89+
endpoints: Callable[[_SVC], Mapping[str, Endpoint]],
7890
interceptors: Iterable[Interceptor] = (),
7991
read_max_bytes: int | None = None,
8092
) -> None:
8193
"""Initialize the ASGI application."""
8294
super().__init__()
83-
if interceptors:
84-
interceptors = resolve_interceptors(interceptors)
85-
endpoints = {
86-
path: _apply_interceptors(endpoint, interceptors)
87-
for path, endpoint in endpoints.items()
88-
}
95+
self._service = service
8996
self._endpoints = endpoints
97+
self._interceptors = interceptors
98+
self._resolved_endpoints = None
9099
self._read_max_bytes = read_max_bytes
91100

92101
async def __call__(
93102
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
94103
) -> None:
95-
assert scope["type"] == "http" # noqa: S101 - only for type narrowing, in practice always true
104+
if scope["type"] == "websocket":
105+
msg = "connect does not support websockets"
106+
raise RuntimeError(msg)
107+
108+
if scope["type"] == "lifespan":
109+
service_iter = None
110+
while True:
111+
msg = await receive()
112+
match msg["type"]:
113+
case "lifespan.startup":
114+
# Need to cast since type checking doesn't seem to narrow well with isasyncgen
115+
if inspect.isasyncgen(self._service):
116+
service_iter = cast(
117+
"AsyncGenerator[_SVC, None]", self._service
118+
)
119+
try:
120+
service = await anext(service_iter)
121+
except Exception as e:
122+
await send(
123+
{
124+
"type": "lifespan.startup.failed",
125+
"message": str(e),
126+
}
127+
)
128+
return None
129+
else:
130+
service = cast("_SVC", self._service)
131+
if (state := scope.get("state")) is not None:
132+
state["endpoints"] = self._resolve_endpoints(service)
133+
await send({"type": "lifespan.startup.complete"})
134+
case "lifespan.shutdown":
135+
if service_iter is not None:
136+
try:
137+
await service_iter.aclose()
138+
except Exception as e:
139+
await send(
140+
{
141+
"type": "lifespan.shutdown.failed",
142+
"message": str(e),
143+
}
144+
)
145+
await send({"type": "lifespan.shutdown.complete"})
146+
return None
147+
148+
if state := scope.get("state"):
149+
endpoints: Mapping[str, Endpoint] = state["endpoints"]
150+
else:
151+
if not self._resolved_endpoints:
152+
if inspect.isasyncgen(self._service):
153+
msg = "ASGI server does not support lifespan but async generator passed for service. Enable lifespan support."
154+
raise RuntimeError(msg)
155+
156+
self._resolved_endpoints = self._resolve_endpoints(
157+
cast("_SVC", self._service)
158+
)
159+
endpoints = self._resolved_endpoints
96160

97161
ctx: RequestContext | None = None
98162
try:
99163
path = scope["path"]
100-
endpoint = self._endpoints.get(path)
164+
endpoint = endpoints.get(path)
101165
if not endpoint and scope["root_path"]:
102166
# The application was mounted at some root so try stripping the prefix.
103167
path = path.removeprefix(scope["root_path"])
104-
endpoint = self._endpoints.get(path)
105-
168+
endpoint = endpoints.get(path)
106169
if not endpoint:
107170
raise HTTPException(HTTPStatus.NOT_FOUND, [])
108171

@@ -381,6 +444,17 @@ async def _handle_error(
381444
)
382445
await send({"type": "http.response.body", "body": body, "more_body": False})
383446

447+
def _resolve_endpoints(self, service: _SVC) -> Mapping[str, Endpoint]:
448+
resolved_endpoints = self._endpoints(service)
449+
if self._interceptors:
450+
resolved_endpoints = {
451+
path: _apply_interceptors(
452+
endpoint, resolve_interceptors(self._interceptors)
453+
)
454+
for path, endpoint in resolved_endpoints.items()
455+
}
456+
return resolved_endpoints
457+
384458

385459
async def _send_stream_response_headers(
386460
send: ASGISendCallable, codec: Codec, compression_name: str, ctx: RequestContext

0 commit comments

Comments
 (0)