|
4 | 4 | import re |
5 | 5 | from dataclasses import dataclass |
6 | 6 | from typing import Any, Optional |
7 | | -from urllib.parse import urlparse, urlunparse |
| 7 | +from urllib.parse import ParseResult, urlparse, urlunparse |
8 | 8 |
|
9 | 9 | from starlette.datastructures import Headers |
10 | 10 | from starlette.requests import Request |
@@ -47,67 +47,75 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An |
47 | 47 | parsed_upstream_url = urlparse(self.upstream_url) |
48 | 48 |
|
49 | 49 | for link in get_links(data): |
50 | | - href = link.get("href") |
51 | | - if not href: |
52 | | - continue |
53 | | - |
54 | 50 | try: |
55 | | - parsed_link = urlparse(href) |
56 | | - |
57 | | - if parsed_link.netloc not in [ |
58 | | - parsed_req_url.netloc, |
59 | | - parsed_upstream_url.netloc, |
60 | | - ]: |
61 | | - logger.debug( |
62 | | - "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)", |
63 | | - href, |
64 | | - parsed_req_url.netloc, |
65 | | - parsed_upstream_url.netloc, |
66 | | - ) |
67 | | - continue |
68 | | - |
69 | | - # If the link path is not a descendant of the upstream path, don't transform it |
70 | | - if parsed_upstream_url.path != "/" and not parsed_link.path.startswith( |
71 | | - parsed_upstream_url.path |
72 | | - ): |
73 | | - logger.debug( |
74 | | - "Ignoring link %s because it is not descendant of upstream path (%s)", |
75 | | - href, |
76 | | - parsed_upstream_url.path, |
77 | | - ) |
78 | | - continue |
79 | | - |
80 | | - # Replace the upstream host with the client's host |
81 | | - if parsed_link.netloc == parsed_upstream_url.netloc: |
82 | | - parsed_link = parsed_link._replace( |
83 | | - netloc=parsed_req_url.netloc |
84 | | - )._replace(scheme=parsed_req_url.scheme) |
85 | | - |
86 | | - # Rewrite the link path |
87 | | - if parsed_upstream_url.path != "/" and parsed_link.path.startswith( |
88 | | - parsed_upstream_url.path |
89 | | - ): |
90 | | - parsed_link = parsed_link._replace( |
91 | | - path=parsed_link.path[len(parsed_upstream_url.path) :] |
92 | | - ) |
93 | | - |
94 | | - # Add the root_path to the link if it exists |
95 | | - if self.root_path: |
96 | | - parsed_link = parsed_link._replace( |
97 | | - path=f"{self.root_path}{parsed_link.path}" |
98 | | - ) |
99 | | - |
100 | | - logger.debug( |
101 | | - "Rewriting %r link %r to %r", |
102 | | - link.get("rel"), |
103 | | - href, |
104 | | - urlunparse(parsed_link), |
105 | | - ) |
106 | | - link["href"] = urlunparse(parsed_link) |
107 | | - |
| 51 | + self._update_link(link, parsed_req_url, parsed_upstream_url) |
108 | 52 | except Exception as e: |
109 | 53 | logger.error( |
110 | | - "Failed to parse link href %r, (ignoring): %s", href, str(e) |
| 54 | + "Failed to parse link href %r, (ignoring): %s", |
| 55 | + link.get("href"), |
| 56 | + str(e), |
111 | 57 | ) |
112 | | - |
113 | 58 | return data |
| 59 | + |
| 60 | + def _update_link( |
| 61 | + self, link: dict[str, Any], request_url: ParseResult, upstream_url: ParseResult |
| 62 | + ) -> None: |
| 63 | + """ |
| 64 | + Ensure that link hrefs that are local to upstream url are rewritten as local to |
| 65 | + the proxy. |
| 66 | + """ |
| 67 | + if "href" not in link: |
| 68 | + logger.warning("Link %r has no href", link) |
| 69 | + return |
| 70 | + |
| 71 | + parsed_link = urlparse(link["href"]) |
| 72 | + |
| 73 | + if parsed_link.netloc not in [ |
| 74 | + request_url.netloc, |
| 75 | + upstream_url.netloc, |
| 76 | + ]: |
| 77 | + logger.debug( |
| 78 | + "Ignoring link %s because it is not for an endpoint behind this proxy (%s or %s)", |
| 79 | + link["href"], |
| 80 | + request_url.netloc, |
| 81 | + upstream_url.netloc, |
| 82 | + ) |
| 83 | + return |
| 84 | + |
| 85 | + # If the link path is not a descendant of the upstream path, don't transform it |
| 86 | + if upstream_url.path != "/" and not parsed_link.path.startswith( |
| 87 | + upstream_url.path |
| 88 | + ): |
| 89 | + logger.debug( |
| 90 | + "Ignoring link %s because it is not descendant of upstream path (%s)", |
| 91 | + link["href"], |
| 92 | + upstream_url.path, |
| 93 | + ) |
| 94 | + return |
| 95 | + |
| 96 | + # Replace the upstream host with the client's host |
| 97 | + if parsed_link.netloc == upstream_url.netloc: |
| 98 | + parsed_link = parsed_link._replace(netloc=request_url.netloc)._replace( |
| 99 | + scheme=request_url.scheme |
| 100 | + ) |
| 101 | + |
| 102 | + # Rewrite the link path |
| 103 | + if upstream_url.path != "/" and parsed_link.path.startswith(upstream_url.path): |
| 104 | + parsed_link = parsed_link._replace( |
| 105 | + path=parsed_link.path[len(upstream_url.path) :] |
| 106 | + ) |
| 107 | + |
| 108 | + # Add the root_path to the link if it exists |
| 109 | + if self.root_path: |
| 110 | + parsed_link = parsed_link._replace( |
| 111 | + path=f"{self.root_path}{parsed_link.path}" |
| 112 | + ) |
| 113 | + |
| 114 | + logger.debug( |
| 115 | + "Rewriting %r link %r to %r", |
| 116 | + link.get("rel"), |
| 117 | + link["href"], |
| 118 | + urlunparse(parsed_link), |
| 119 | + ) |
| 120 | + |
| 121 | + link["href"] = urlunparse(parsed_link) |
0 commit comments