|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | import base64 |
| 5 | +import ipaddress |
5 | 6 | import os |
6 | 7 | import re |
7 | 8 | import ssl |
8 | 9 | import tempfile |
9 | 10 | import textwrap |
10 | 11 | from typing import Mapping, Optional, Set |
| 12 | +from urllib.parse import urlsplit |
11 | 13 |
|
12 | 14 | import fido2.features |
13 | 15 | from flask import Flask, has_request_context, request |
@@ -188,16 +190,67 @@ def determine_rp_id(explicit_id: Optional[str] = None) -> str: |
188 | 190 | return configured_id.strip() |
189 | 191 |
|
190 | 192 | if has_request_context(): |
191 | | - host = request.host.split(":", 1)[0].strip().lower() |
| 193 | + host = _resolve_request_host() |
192 | 194 | if host in {"", None}: |
193 | 195 | return "localhost" |
| 196 | + try: |
| 197 | + if ipaddress.ip_address(host).is_loopback: |
| 198 | + return "localhost" |
| 199 | + except ValueError: |
| 200 | + pass |
194 | 201 | if host in {"127.0.0.1", "::1"}: |
195 | 202 | return "localhost" |
196 | 203 | return host |
197 | 204 |
|
198 | 205 | return "localhost" |
199 | 206 |
|
200 | 207 |
|
| 208 | +def _resolve_request_host() -> Optional[str]: |
| 209 | + """Return the current request host without port decoration.""" |
| 210 | + |
| 211 | + if not has_request_context(): |
| 212 | + return None |
| 213 | + |
| 214 | + for raw_host in ( |
| 215 | + request.headers.get("Host"), |
| 216 | + request.environ.get("HTTP_HOST"), |
| 217 | + request.environ.get("SERVER_NAME"), |
| 218 | + ): |
| 219 | + host = _normalise_request_host(raw_host) |
| 220 | + if host: |
| 221 | + return host |
| 222 | + |
| 223 | + return None |
| 224 | + |
| 225 | + |
| 226 | +def _normalise_request_host(raw_host: Optional[str]) -> Optional[str]: |
| 227 | + """Normalise a raw host header into a lowercase hostname or IP literal.""" |
| 228 | + |
| 229 | + if not isinstance(raw_host, str): |
| 230 | + return None |
| 231 | + |
| 232 | + host = raw_host.strip().lower() |
| 233 | + if not host: |
| 234 | + return None |
| 235 | + |
| 236 | + if host.startswith("["): |
| 237 | + closing_index = host.find("]") |
| 238 | + if closing_index != -1: |
| 239 | + unwrapped = host[1:closing_index].strip() |
| 240 | + return unwrapped or None |
| 241 | + |
| 242 | + if host.count(":") > 1: |
| 243 | + # Treat unbracketed multi-colon values as IPv6 literals without ports. |
| 244 | + return host |
| 245 | + |
| 246 | + parsed = urlsplit(f"//{host}") |
| 247 | + normalised = parsed.hostname |
| 248 | + if isinstance(normalised, str) and normalised.strip(): |
| 249 | + return normalised.strip().lower() |
| 250 | + |
| 251 | + return host |
| 252 | + |
| 253 | + |
201 | 254 | def build_rp_entity( |
202 | 255 | rp_data: Optional[Mapping[str, str]] = None, |
203 | 256 | *, |
|
0 commit comments