Skip to content

Commit 856e7fc

Browse files
fix iter_origins
1 parent 10e459f commit 856e7fc

File tree

3 files changed

+37
-20
lines changed

3 files changed

+37
-20
lines changed

services/web/server/src/simcore_service_webserver/api_keys/_controller/rest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ...login.decorators import login_required
2323
from ...models import RequestContext
2424
from ...security.decorators import permission_required
25-
from ...utils_aiohttp import envelope_json_response, iter_originating_hosts
25+
from ...utils_aiohttp import envelope_json_response, iter_origins
2626
from .. import _service
2727
from ..models import ApiKey
2828
from .rest_exceptions import handle_plugin_requests_exceptions
@@ -38,7 +38,7 @@ class ApiKeysPathParams(StrictRequestParameters):
3838

3939

4040
def _get_api_base_url(request: web.Request) -> str | None:
41-
originating_host = next(iter_originating_hosts(request), None)
41+
originating_host = next(iter_origins(request), None)
4242
if not originating_host:
4343
return None
4444

services/web/server/src/simcore_service_webserver/products/_web_middlewares.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .._meta import API_VTAG
1010
from ..constants import APP_PRODUCTS_KEY, RQ_PRODUCT_KEY
11-
from ..utils_aiohttp import iter_originating_hosts
11+
from ..utils_aiohttp import iter_origins
1212
from .models import Product
1313

1414
_logger = logging.getLogger(__name__)
@@ -22,7 +22,7 @@ def _get_default_product_name(app: web.Application) -> str:
2222
def _discover_product_by_hostname(request: web.Request) -> str | None:
2323
products: OrderedDict[str, Product] = request.app[APP_PRODUCTS_KEY]
2424
for product in products.values():
25-
for host in iter_originating_hosts(request):
25+
for host in iter_origins(request):
2626
if product.host_regex.search(host):
2727
product_name: str = product.name
2828
return product_name

services/web/server/src/simcore_service_webserver/utils_aiohttp.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -131,34 +131,51 @@ class NextPage(BaseModel, Generic[PageParameters]):
131131
parameters: PageParameters | None = None
132132

133133

134-
def iter_originating_hosts(request: web.Request) -> Iterator[str]:
134+
def iter_origins(request: web.Request) -> Iterator[str]:
135135
#
136136
# SEE https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-Host
137137
# SEE https://doc.traefik.io/traefik/getting-started/faq/#what-are-the-forwarded-headers-when-proxying-http-requests
138138
seen = set()
139139

140-
forwarded = request.headers.get("X-Forwarded-Host")
141-
if forwarded:
140+
fwd_protos = [
141+
p.strip()
142+
for p in request.headers.get("X-Forwarded-Proto").split(",")
143+
if p.strip()
144+
]
145+
fwd_hosts = [
146+
h.strip()
147+
for h in request.headers.get("X-Forwarded-Host").split(",")
148+
if h.strip()
149+
]
150+
fwd_ports = [
151+
pt.strip()
152+
for pt in request.headers.get("X-Forwarded-Port").split(",")
153+
if pt.strip()
154+
]
155+
156+
fwd_origins = [
157+
f"{proto}://{host}:{port}"
158+
for proto, host, port in zip(fwd_protos, fwd_hosts, fwd_ports, strict=False)
159+
]
160+
if fwd_origins:
142161
# X-Forwarded-Host can contain a comma-separated list of hosts
143162
# (when the request passes through multiple proxies)
144-
for host in forwarded.split(","):
145-
stripped_host = host.strip().partition(":")[0]
146-
if stripped_host and stripped_host not in seen:
147-
seen.add(stripped_host)
148-
yield host
163+
for origin in fwd_origins:
164+
if origin and origin not in seen:
165+
seen.add(origin)
166+
yield origin
149167

150168
# Fallback to request.host
151-
if request.host:
152-
host = request.host.partition(":")[0]
153-
if host not in seen:
154-
yield host
169+
if request.url:
170+
origin = f"{request.url.scheme}://{request.url.host}"
171+
if request.url.port:
172+
origin += f":{request.url.port}"
173+
yield origin
155174

156175

157176
def get_api_base_url(request: web.Request) -> str:
158-
originating_host = next(iter_originating_hosts(request))
177+
api_host = next(iter_origins(request))
159178
api_host = (
160-
f"api.{originating_host}"
161-
if not is_ip_address(originating_host)
162-
else originating_host
179+
f"api.{api_host}" if not is_ip_address(api_host) else api_host # in tests
163180
)
164181
return f"{request.url.with_host(api_host).with_port(None).with_path('')}"

0 commit comments

Comments
 (0)