|
6 | 6 | import secrets |
7 | 7 | import threading |
8 | 8 | import webbrowser |
9 | | -from typing import Optional, Tuple |
| 9 | +from typing import Callable, Optional, Tuple |
10 | 10 |
|
11 | 11 | from asgiref.typing import ( |
12 | 12 | ASGI3Application, |
13 | 13 | ASGIReceiveCallable, |
14 | 14 | ASGISendCallable, |
15 | 15 | ASGISendEvent, |
| 16 | + HTTPResponseStartEvent, |
16 | 17 | Scope, |
17 | 18 | ) |
18 | 19 |
|
@@ -111,40 +112,15 @@ async def __call__( |
111 | 112 | ) -> None: |
112 | 113 | if scope["type"] != "http" or scope["path"] != "/" or len(self.script) == 0: |
113 | 114 | return await self.app(scope, receive, send) |
114 | | - intercept = True |
115 | | - body = b"" |
116 | | - |
117 | | - async def rewrite_send(event: ASGISendEvent) -> None: |
118 | | - nonlocal intercept |
119 | | - nonlocal body |
120 | | - |
121 | | - if intercept: |
122 | | - if event["type"] == "http.response.start": |
123 | | - # Must remove Content-Length, if present; if we insert our |
124 | | - # scripts, it won't be correct anymore |
125 | | - event["headers"] = [ |
126 | | - (name, value) |
127 | | - for (name, value) in event["headers"] |
128 | | - if name.decode("ascii").lower() != "content-length" |
129 | | - ] |
130 | | - elif event["type"] == "http.response.body": |
131 | | - body += event["body"] |
132 | | - if b"</head>" in body: |
133 | | - event["body"] = body.replace(b"</head>", self.script, 1) |
134 | | - body = b"" # Allow gc |
135 | | - intercept = False |
136 | | - elif "more_body" in event and event["more_body"]: |
137 | | - # DO NOT send the response; wait for more data |
138 | | - return |
139 | | - else: |
140 | | - # The entire response was seen, and we never encountered |
141 | | - # any </head>. Just send everything we have |
142 | | - event["body"] = body |
143 | | - body = b"" # Allow gc |
144 | | - |
145 | | - return await send(event) |
146 | | - |
147 | | - await self.app(scope, receive, rewrite_send) |
| 115 | + |
| 116 | + def mangle_callback(body: bytes) -> Tuple[bytes, bool]: |
| 117 | + if b"</head>" in body: |
| 118 | + return (body.replace(b"</head>", self.script, 1), True) |
| 119 | + else: |
| 120 | + return (body, False) |
| 121 | + |
| 122 | + mangler = ResponseMangler(send, mangle_callback) |
| 123 | + await self.app(scope, receive, mangler.send) |
148 | 124 |
|
149 | 125 |
|
150 | 126 | # PARENT PROCESS ------------------------------------------------------------ |
@@ -231,3 +207,97 @@ async def process_request( |
231 | 207 |
|
232 | 208 | async with serve(reload_server, "127.0.0.1", port, process_request=process_request): |
233 | 209 | await asyncio.Future() # wait forever |
| 210 | + |
| 211 | + |
| 212 | +class ResponseMangler: |
| 213 | + """A class that assists with intercepting and rewriting response bodies being sent |
| 214 | + over ASGI. This would be easy if not for 1) response bodies are potentially sent in |
| 215 | + chunks, over multiple events; 2) the first response event we receive is the one that |
| 216 | + contains the Content-Length, which can be affected when we do rewriting later on. |
| 217 | + The ResponseMangler handles the buffering and content-length rewriting, leaving the |
| 218 | + caller to only have to worry about the actual body-modifying logic. |
| 219 | + """ |
| 220 | + |
| 221 | + def __init__( |
| 222 | + self, send: ASGISendCallable, mangler: Callable[[bytes], Tuple[bytes, bool]] |
| 223 | + ) -> None: |
| 224 | + # The underlying ASGI send function |
| 225 | + self._send = send |
| 226 | + # The caller-provided logic for rewriting the body. Takes a single `bytes` |
| 227 | + # argument that is _all_ of the body bytes seen _so far_, and returns a tuple of |
| 228 | + # (bytes, bool) where the bytes are the (possibly modified) body bytes and the |
| 229 | + # bool is True if the mangler does not care to see any more data. |
| 230 | + self._mangler = mangler |
| 231 | + |
| 232 | + # If True, the mangler is done and any further data can simply be passed along |
| 233 | + self._done: bool = False |
| 234 | + |
| 235 | + # Holds the http.response.start event, which may need its Content-Length header |
| 236 | + # rewritten before we send it |
| 237 | + self._response_start: Optional[HTTPResponseStartEvent] = None |
| 238 | + # All the response body bytes we have seen so far |
| 239 | + self._body: bytes = b"" |
| 240 | + |
| 241 | + async def send(self, event: ASGISendEvent) -> None: |
| 242 | + if self._done: |
| 243 | + await self._send(event) |
| 244 | + return |
| 245 | + |
| 246 | + if event["type"] == "http.response.start": |
| 247 | + self._response_start = event |
| 248 | + elif event["type"] == "http.response.body": |
| 249 | + # This check is mostly to make pyright happy |
| 250 | + if self._response_start is None: |
| 251 | + raise AssertionError( |
| 252 | + "http.response.body ASGI event sent before http.response.start" |
| 253 | + ) |
| 254 | + |
| 255 | + # Add the newly received body data to what we've seen already |
| 256 | + self._body += event["body"] |
| 257 | + # Snapshot length before we mess with the body |
| 258 | + old_len = len(self._body) |
| 259 | + # Mangle away! If done is True, the mangler doesn't want to do any further |
| 260 | + # mangling. |
| 261 | + self._body, done = self._mangler(self._body) |
| 262 | + |
| 263 | + new_len = len(self._body) |
| 264 | + if new_len != old_len: |
| 265 | + # The mangling check changed the length of the body. Add the difference |
| 266 | + # to the content-length header (if content-length is even present) |
| 267 | + _add_to_content_length(self._response_start, new_len - old_len) |
| 268 | + |
| 269 | + more_body = event.get("more_body", False) |
| 270 | + |
| 271 | + if done or not more_body: |
| 272 | + # Either we've seen the whole body by now (`not more_body`) or the |
| 273 | + # mangler has seen all the data it cares to (`done`). Either way, we can |
| 274 | + # send all the data we have. |
| 275 | + self._done = True |
| 276 | + await self._send(self._response_start) |
| 277 | + await self._send( |
| 278 | + { |
| 279 | + "type": "http.response.body", |
| 280 | + "body": self._body, |
| 281 | + "more_body": more_body, |
| 282 | + } |
| 283 | + ) |
| 284 | + # Allow gc |
| 285 | + self._response_start = None |
| 286 | + self._body = b"" |
| 287 | + else: |
| 288 | + # If we get here, then the mangler isn't done and we are expecting to |
| 289 | + # see more data. Do nothing. |
| 290 | + pass |
| 291 | + |
| 292 | + |
| 293 | +def _add_to_content_length(event: HTTPResponseStartEvent, offset: int) -> None: |
| 294 | + """If event has a Content-Length header, add the specified number of bytes to it |
| 295 | + (may be negative)""" |
| 296 | + event["headers"] = [ |
| 297 | + ( |
| 298 | + (name, str(int(value) + offset).encode("latin-1")) |
| 299 | + if name.decode("ascii").lower() == "content-length" |
| 300 | + else (name, value) |
| 301 | + ) |
| 302 | + for (name, value) in event["headers"] |
| 303 | + ] |
0 commit comments