|
1 | | -import base64 |
2 | | -import logging |
3 | | -from abc import ABC, abstractmethod |
4 | | -from typing import Any, Awaitable, Callable, Dict, Mapping, Optional |
| 1 | +from typing import Awaitable, Callable |
5 | 2 |
|
6 | | -from itsdangerous import Serializer, URLSafeTimedSerializer # noqa |
7 | | -from itsdangerous.exc import BadSignature, SignatureExpired |
8 | | - |
9 | | -from blacksheep.cookies import Cookie |
10 | 3 | from blacksheep.messages import Request, Response |
11 | | -from blacksheep.settings.json import json_settings |
12 | | -from blacksheep.utils import ensure_str |
13 | | - |
14 | | - |
15 | | -def get_logger(): |
16 | | - logger = logging.getLogger("blacksheep.sessions") |
17 | | - logger.setLevel(logging.INFO) |
18 | | - return logger |
19 | | - |
20 | | - |
21 | | -class Session: |
22 | | - def __init__(self, values: Optional[Mapping[str, Any]] = None) -> None: |
23 | | - if values is None: |
24 | | - values = {} |
25 | | - self._modified = False |
26 | | - self._values = dict(values) |
27 | | - |
28 | | - @property |
29 | | - def modified(self) -> bool: |
30 | | - return self._modified |
31 | | - |
32 | | - def get(self, name: str, default: Any = None) -> Any: |
33 | | - return self._values.get(name, default) |
34 | | - |
35 | | - def set(self, name: str, value: Any) -> None: |
36 | | - self._modified = True |
37 | | - self._values[name] = value |
38 | | - |
39 | | - def update(self, values: Mapping[str, Any]) -> None: |
40 | | - self._modified = True |
41 | | - self._values.update(values) |
42 | | - |
43 | | - def __getitem__(self, name: str) -> Any: |
44 | | - return self._values[name] |
45 | | - |
46 | | - def __setitem__(self, name: str, value: Any) -> None: |
47 | | - self._modified = True |
48 | | - self._values[name] = value |
49 | | - |
50 | | - def __delitem__(self, name: str) -> None: |
51 | | - self._modified = True |
52 | | - del self._values[name] |
| 4 | +from blacksheep.sessions.abc import Session, SessionSerializer, SessionStore |
53 | 5 |
|
54 | | - def __contains__(self, name: str) -> bool: |
55 | | - return name in self._values |
56 | | - |
57 | | - def __len__(self) -> int: |
58 | | - return len(self._values) |
59 | | - |
60 | | - def __eq__(self, o: object) -> bool: |
61 | | - if self is o: |
62 | | - return True |
63 | | - if isinstance(o, Session): |
64 | | - return self._values == o._values |
65 | | - return self._values == o |
66 | | - |
67 | | - def clear(self) -> None: |
68 | | - self._modified = True |
69 | | - self._values.clear() |
70 | | - |
71 | | - def to_dict(self) -> Dict[str, Any]: |
72 | | - return self._values.copy() |
73 | | - |
74 | | - |
75 | | -class SessionSerializer(ABC): |
76 | | - @abstractmethod |
77 | | - def read(self, value: str) -> Session: |
78 | | - """Creates an instance of Session from a string representation.""" |
79 | | - |
80 | | - @abstractmethod |
81 | | - def write(self, session: Session) -> str: |
82 | | - """Creates the string representation of a session.""" |
83 | | - |
84 | | - |
85 | | -class JSONSerializer(SessionSerializer): |
86 | | - def read(self, value: str) -> Session: |
87 | | - return Session(json_settings.loads(value)) |
88 | | - |
89 | | - def write(self, session: Session) -> str: |
90 | | - return json_settings.dumps(session.to_dict()) |
| 6 | +__all__ = [ |
| 7 | + "Session", |
| 8 | + "SessionMiddleware", |
| 9 | + "SessionStore", |
| 10 | + "SessionSerializer", |
| 11 | +] |
91 | 12 |
|
92 | 13 |
|
93 | 14 | class SessionMiddleware: |
94 | | - def __init__( |
95 | | - self, |
96 | | - secret_key: str, |
97 | | - *, |
98 | | - session_cookie: str = "session", |
99 | | - serializer: Optional[SessionSerializer] = None, |
100 | | - signer: Optional[Serializer] = None, |
101 | | - session_max_age: Optional[int] = None, |
102 | | - ) -> None: |
103 | | - self._signer = signer or URLSafeTimedSerializer(secret_key) |
104 | | - self._serializer = serializer or JSONSerializer() |
105 | | - self._session_cookie = session_cookie |
106 | | - self._logger = get_logger() |
107 | | - if session_max_age is not None and session_max_age < 1: |
108 | | - raise ValueError("session_max_age must be a positive number greater than 0") |
109 | | - self.session_max_age = session_max_age |
| 15 | + """ |
| 16 | + Middleware for managing user sessions in a BlackSheep application. |
110 | 17 |
|
111 | | - def try_read_session(self, raw_value: str) -> Session: |
112 | | - try: |
113 | | - if self.session_max_age: |
114 | | - assert isinstance(self._signer, URLSafeTimedSerializer), ( |
115 | | - "To use a session_max_age, the configured signer must be of " |
116 | | - + " TimestampSigner type" |
117 | | - ) |
118 | | - unsigned_value = self._signer.loads( |
119 | | - raw_value, max_age=self.session_max_age |
120 | | - ) |
121 | | - else: |
122 | | - unsigned_value = self._signer.loads(raw_value) |
123 | | - except SignatureExpired: |
124 | | - self._logger.info("The session signature has expired.") |
125 | | - return Session() |
126 | | - except BadSignature: |
127 | | - # the client might be sending forged tokens |
128 | | - self._logger.info("The session signature verification failed.") |
129 | | - return Session() |
| 18 | + This middleware loads the session from the provided session store at the beginning |
| 19 | + of the request, attaches it to the request object, and saves the session back to |
| 20 | + the store if it was modified during request processing. |
130 | 21 |
|
131 | | - # in this case, we don't try because if the signature verification worked, |
132 | | - # we expect the value to be valid - if reading fails here it's a bug in |
133 | | - # in the serializer class |
134 | | - return self._serializer.read(base64.b64decode(unsigned_value).decode("utf8")) |
| 22 | + Args: |
| 23 | + store (SessionStore): The session store used to load and save session data. |
135 | 24 |
|
136 | | - def write_session(self, session: Session) -> str: |
137 | | - payload = base64.b64encode( |
138 | | - self._serializer.write(session).encode("utf8") |
139 | | - ).decode() |
140 | | - return ensure_str(self._signer.dumps(payload)) # type: ignore |
| 25 | + Usage: |
| 26 | + Add this middleware to your application to enable session support. |
| 27 | + """ |
141 | 28 |
|
142 | | - def prepare_cookie(self, value: str) -> Cookie: |
143 | | - return Cookie(self._session_cookie, value, path="/", http_only=True) |
| 29 | + def __init__(self, store: SessionStore) -> None: |
| 30 | + self._store = store |
144 | 31 |
|
145 | 32 | async def __call__( |
146 | 33 | self, request: Request, handler: Callable[[Request], Awaitable[Response]] |
147 | 34 | ) -> Response: |
148 | | - session: Optional[Session] = None |
149 | | - current_session_value = request.cookies.get(self._session_cookie, None) |
150 | | - if current_session_value: |
151 | | - session = self.try_read_session(current_session_value) |
152 | | - else: |
153 | | - session = Session() |
| 35 | + session = await self._store.load(request) |
154 | 36 | request.session = session |
155 | | - |
156 | 37 | response = await handler(request) |
157 | | - |
158 | 38 | if session.modified: |
159 | | - response.set_cookie(self.prepare_cookie(self.write_session(session))) |
| 39 | + await self._store.save(request, response, session) |
160 | 40 | return response |
0 commit comments