@@ -175,65 +175,91 @@ to be able to intercept RPC messages. However, many interceptors, such as for au
175175tracing, only need access to headers and not messages. Connect provides a metadata interceptor
176176protocol that can be implemented to work with any RPC type.
177177
178- An authentication interceptor checking bearer tokens may look like this:
178+ An authentication interceptor checking bearer tokens and storing them to a context variable may look like this:
179179
180180=== "ASGI"
181181
182182 ```python
183- class AuthInterceptor:
183+ from contextvars import ContextVar, Token
184+
185+ _auth_token = ContextVar["auth_token"]("current_auth_token")
186+
187+ class ServerAuthInterceptor:
184188 def __init__(self, valid_tokens: list[str]):
185189 self._valid_tokens = valid_tokens
186190
187- async def on_start(self, ctx: RequestContext):
191+ async def on_start(self, ctx: RequestContext) -> Token["auth_token"] :
188192 authorization = ctx.request_headers().get("authorization")
189193 if not authorization or not authorization.startswith("Bearer "):
190194 raise ConnectError(Code.UNAUTHENTICATED)
191195 token = authorization[len("Bearer "):]
192- if token not in valid_tokens :
196+ if token not in self._valid_tokens :
193197 raise ConnectError(Code.PERMISSION_DENIED)
198+ return _auth_token.set(token)
199+
200+ async def on_end(self, token: Token["auth_token"], ctx: RequestContext):
201+ _auth_token.reset(token)
194202 ```
195203
196204=== "WSGI"
197205
198206 ```python
199- class AuthInterceptor:
207+ from contextvars import ContextVar, Token
208+
209+ _auth_token = ContextVar["auth_token"]("current_auth_token")
210+
211+ class ServerAuthInterceptor:
200212 def __init__(self, valid_tokens: list[str]):
201213 self._valid_tokens = valid_tokens
202214
203- def on_start(self, ctx: RequestContext):
215+ def on_start(self, ctx: RequestContext) -> Token["auth_token"] :
204216 authorization = ctx.request_headers().get("authorization")
205217 if not authorization or not authorization.startswith("Bearer "):
206218 raise ConnectError(Code.UNAUTHENTICATED)
207219 token = authorization[len("Bearer "):]
208- if token not in valid_tokens :
220+ if token not in self._valid_tokens :
209221 raise ConnectError(Code.PERMISSION_DENIED)
222+ return _auth_token.set(token)
223+
224+ def on_end(self, token: Token["auth_token"], ctx: RequestContext):
225+ _auth_token.reset(token)
210226 ```
211227
212- ` on_start ` can return any value, which is passed to the optional ` on_end ` method. This can be
213- used, for example, to record the time of execution for the method .
228+ ` on_start ` can return any value, which is passed to the optional ` on_end ` method. Here, we
229+ return the token to reset the context variable .
214230
215- === "ASGI"
231+ Clients can add an interceptor that reads the token from the context variable and populates
232+ the authorization header.
233+
234+ === "Async"
216235
217236 ```python
218- import time
237+ from contextvars import ContextVar
219238
220- class TimingInterceptor:
221- async def on_start(self, ctx: RequestContext) -> float:
222- return time.perf_counter()
239+ _auth_token = ContextVar["auth_token"]("current_auth_token")
223240
224- async def on_end(self, token: float, ctx: RequestContext):
225- print(f"Method took {} seconds.", token - time.perf_counter())
241+ class ClientAuthInterceptor:
242+ async def on_start(self, ctx: RequestContext):
243+ auth_token = _auth_token.get(None)
244+ if auth_token:
245+ ctx.request_headers()["authorization"] = f"Bearer {auth_token}"
226246 ```
227247
228- === "WSGI "
248+ === "Sync "
229249
230250 ```python
231- import time
251+ from contextvars import ContextVar
232252
233- class TimingInterceptor:
234- def on_start(self, ctx: RequestContext):
235- return time.perf_counter()
253+ _auth_token = ContextVar["auth_token"]("current_auth_token")
236254
237- def on_end(self, token: float, ctx: RequestContext):
238- print(f"Method took {} seconds.", token - time.perf_counter())
255+ class ClientAuthInterceptor:
256+ def on_start(self, ctx: RequestContext):
257+ auth_token = _auth_token.get(None)
258+ if auth_token:
259+ ctx.request_headers()["authorization"] = f"Bearer {auth_token}"
239260 ```
261+
262+ Note that in the client interceptor, we do not need to define ` on_end ` .
263+
264+ The above interceptors would allow a server to receive and validate an auth token and automatically
265+ propagate it to the authorization header of backend calls.
0 commit comments