Skip to content

Commit c6a452f

Browse files
committed
Add doc for client metadata interceptor
Signed-off-by: Anuraag Agrawal <[email protected]>
1 parent 5bae39f commit c6a452f

File tree

1 file changed

+46
-20
lines changed

1 file changed

+46
-20
lines changed

docs/interceptors.md

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -175,28 +175,40 @@ to be able to intercept RPC messages. However, many interceptors, such as for au
175175
tracing, only need access to headers and not messages. Connect provides a metadata interceptor
176176
protocol that can be implemented to work with any RPC type.
177177

178-
An authentication interceptor checking bearer tokens may look like this:
178+
An authentication interceptor checking bearer tokens and storing them to a context variable may look like this:
179179

180180
=== "ASGI"
181181

182182
```python
183-
class AuthInterceptor:
183+
from contextvars import ContextVar, Token
184+
185+
_auth_token = ContextVar["auth_token"]("current_auth_token")
186+
187+
class ServerAuthInterceptor:
184188
def __init__(self, valid_tokens: list[str]):
185189
self._valid_tokens = valid_tokens
186190

187-
async def on_start(self, ctx: RequestContext):
191+
async def on_start(self, ctx: RequestContext) -> Token["auth_token"]:
188192
authorization = ctx.request_headers().get("authorization")
189193
if not authorization or not authorization.startswith("Bearer "):
190194
raise ConnectError(Code.UNAUTHENTICATED)
191195
token = authorization[len("Bearer "):]
192196
if token not in valid_tokens:
193197
raise ConnectError(Code.PERMISSION_DENIED)
198+
return _auth_token.set(token)
199+
200+
async def on_end(self, token: Token["auth_token"], ctx: RequestContext):
201+
_auth_token.reset(token)
194202
```
195203

196204
=== "WSGI"
197205

198206
```python
199-
class AuthInterceptor:
207+
from contextvars import ContextVar, Token
208+
209+
_auth_token = ContextVar["auth_token"]("current_auth_token")
210+
211+
class ServerAuthInterceptor:
200212
def __init__(self, valid_tokens: list[str]):
201213
self._valid_tokens = valid_tokens
202214

@@ -207,33 +219,47 @@ An authentication interceptor checking bearer tokens may look like this:
207219
token = authorization[len("Bearer "):]
208220
if token not in valid_tokens:
209221
raise ConnectError(Code.PERMISSION_DENIED)
222+
return _auth_token.set(token)
223+
224+
def on_end(self, token: Token["auth_token"], ctx: RequestContext):
225+
_auth_token.reset(token)
210226
```
211227

212-
`on_start` can return any value, which is passed to the optional `on_end` method. This can be
213-
used, for example, to record the time of execution for the method.
228+
`on_start` can return any value, which is passed to the optional `on_end` method. Here, we
229+
return the token to reset the context variable.
214230

215-
=== "ASGI"
231+
Clients can add an interceptor that reads the token from the context variable and populates
232+
the authorization header.
233+
234+
=== "Async"
216235

217236
```python
218-
import time
237+
from contextvars import ContextVar, Token
219238

220-
class TimingInterceptor:
221-
async def on_start(self, ctx: RequestContext) -> float:
222-
return time.perf_counter()
239+
_auth_token = ContextVar["auth_token"]("current_auth_token")
223240

224-
async def on_end(self, token: float, ctx: RequestContext):
225-
print(f"Method took {} seconds.", token - time.perf_counter())
241+
class ClientAuthInterceptor:
242+
async def on_start(self, ctx: RequestContext) -> Token["auth_token"]:
243+
auth_token = _auth_token.get(None)
244+
if auth_token:
245+
ctx.request_headers()["authorization"] = f"Bearer {auth_token}"
226246
```
227247

228-
=== "WSGI"
248+
=== "Sync"
229249

230250
```python
231-
import time
251+
from contextvars import ContextVar, Token
232252

233-
class TimingInterceptor:
234-
def on_start(self, ctx: RequestContext):
235-
return time.perf_counter()
253+
_auth_token = ContextVar["auth_token"]("current_auth_token")
236254

237-
def on_end(self, token: float, ctx: RequestContext):
238-
print(f"Method took {} seconds.", token - time.perf_counter())
255+
class ClientAuthInterceptor:
256+
def on_start(self, ctx: RequestContext) -> Token["auth_token"]:
257+
auth_token = _auth_token.get(None)
258+
if auth_token:
259+
ctx.request_headers()["authorization"] = f"Bearer {auth_token}"
239260
```
261+
262+
Note that in the client interceptor, we do not need to define `on_end`.
263+
264+
The above interceptors would allow a server to receive and validate an auth token and automatically
265+
propagate it to the authorization header of backend calls.

0 commit comments

Comments
 (0)