Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 49 additions & 23 deletions docs/interceptors.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,65 +175,91 @@ to be able to intercept RPC messages. However, many interceptors, such as for au
tracing, only need access to headers and not messages. Connect provides a metadata interceptor
protocol that can be implemented to work with any RPC type.

An authentication interceptor checking bearer tokens may look like this:
An authentication interceptor checking bearer tokens and storing them to a context variable may look like this:

=== "ASGI"

```python
class AuthInterceptor:
from contextvars import ContextVar, Token

_auth_token = ContextVar["auth_token"]("current_auth_token")

class ServerAuthInterceptor:
def __init__(self, valid_tokens: list[str]):
self._valid_tokens = valid_tokens

async def on_start(self, ctx: RequestContext):
async def on_start(self, ctx: RequestContext) -> Token["auth_token"]:
authorization = ctx.request_headers().get("authorization")
if not authorization or not authorization.startswith("Bearer "):
raise ConnectError(Code.UNAUTHENTICATED)
token = authorization[len("Bearer "):]
if token not in valid_tokens:
if token not in self._valid_tokens:
raise ConnectError(Code.PERMISSION_DENIED)
return _auth_token.set(token)

async def on_end(self, token: Token["auth_token"], ctx: RequestContext):
_auth_token.reset(token)
```

=== "WSGI"

```python
class AuthInterceptor:
from contextvars import ContextVar, Token

_auth_token = ContextVar["auth_token"]("current_auth_token")

class ServerAuthInterceptor:
def __init__(self, valid_tokens: list[str]):
self._valid_tokens = valid_tokens

def on_start(self, ctx: RequestContext):
def on_start(self, ctx: RequestContext) -> Token["auth_token"]:
authorization = ctx.request_headers().get("authorization")
if not authorization or not authorization.startswith("Bearer "):
raise ConnectError(Code.UNAUTHENTICATED)
token = authorization[len("Bearer "):]
if token not in valid_tokens:
if token not in self._valid_tokens:
raise ConnectError(Code.PERMISSION_DENIED)
return _auth_token.set(token)

def on_end(self, token: Token["auth_token"], ctx: RequestContext):
_auth_token.reset(token)
```

`on_start` can return any value, which is passed to the optional `on_end` method. This can be
used, for example, to record the time of execution for the method.
`on_start` can return any value, which is passed to the optional `on_end` method. Here, we
return the token to reset the context variable.

=== "ASGI"
Clients can add an interceptor that reads the token from the context variable and populates
the authorization header.

=== "Async"

```python
import time
from contextvars import ContextVar

class TimingInterceptor:
async def on_start(self, ctx: RequestContext) -> float:
return time.perf_counter()
_auth_token = ContextVar["auth_token"]("current_auth_token")

async def on_end(self, token: float, ctx: RequestContext):
print(f"Method took {} seconds.", token - time.perf_counter())
class ClientAuthInterceptor:
async def on_start(self, ctx: RequestContext):
auth_token = _auth_token.get(None)
if auth_token:
ctx.request_headers()["authorization"] = f"Bearer {auth_token}"
```

=== "WSGI"
=== "Sync"

```python
import time
from contextvars import ContextVar

class TimingInterceptor:
def on_start(self, ctx: RequestContext):
return time.perf_counter()
_auth_token = ContextVar["auth_token"]("current_auth_token")

def on_end(self, token: float, ctx: RequestContext):
print(f"Method took {} seconds.", token - time.perf_counter())
class ClientAuthInterceptor:
def on_start(self, ctx: RequestContext):
auth_token = _auth_token.get(None)
if auth_token:
ctx.request_headers()["authorization"] = f"Bearer {auth_token}"
```

Note that in the client interceptor, we do not need to define `on_end`.

The above interceptors would allow a server to receive and validate an auth token and automatically
propagate it to the authorization header of backend calls.