Skip to content

Commit 7e9a754

Browse files
pierrejeambrunnailo2c
authored andcommitted
AIP-84 Add safe url helper method (apache#47577)
* AIP-84 Add safe url helper method * Following code review
1 parent 405bb72 commit 7e9a754

File tree

4 files changed

+71
-5
lines changed

4 files changed

+71
-5
lines changed

airflow/api_fastapi/core_api/routes/public/login.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19-
from fastapi import Request, status
19+
from fastapi import HTTPException, Request, status
2020
from fastapi.responses import RedirectResponse
2121

2222
from airflow.api_fastapi.common.router import AirflowRouter
2323
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
24+
from airflow.api_fastapi.core_api.security import is_safe_url
2425

2526
login_router = AirflowRouter(tags=["Login"], prefix="/login")
2627

@@ -33,6 +34,9 @@ def login(request: Request, next: None | str = None) -> RedirectResponse:
3334
"""Redirect to the login URL depending on the AuthManager configured."""
3435
login_url = request.app.state.auth_manager.get_url_login()
3536

37+
if next and not is_safe_url(next):
38+
raise HTTPException(status_code=400, detail="Invalid or unsafe next URL")
39+
3640
if next:
3741
login_url += f"?next={next}"
3842
return RedirectResponse(login_url)

airflow/api_fastapi/core_api/security.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from __future__ import annotations
1818

1919
from functools import cache
20+
from pathlib import Path
2021
from typing import TYPE_CHECKING, Annotated, Callable
22+
from urllib.parse import urljoin, urlparse
2123

2224
from fastapi import Depends, HTTPException, Request, status
2325
from fastapi.security import OAuth2PasswordBearer
@@ -253,3 +255,23 @@ def _requires_access(
253255
) -> None:
254256
if not is_authorized_callback():
255257
raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden")
258+
259+
260+
def is_safe_url(target_url: str) -> bool:
261+
"""
262+
Check that the URL is safe.
263+
264+
Needs to belong to the same domain as base_url, use HTTP or HTTPS (no JavaScript/data schemes),
265+
is a valid normalized path.
266+
"""
267+
base_url = conf.get("api", "base_url")
268+
269+
parsed_base = urlparse(base_url)
270+
parsed_target = urlparse(urljoin(base_url, target_url)) # Resolves relative URLs
271+
272+
target_path = Path(parsed_target.path).resolve()
273+
274+
if target_path and parsed_base.path and not target_path.is_relative_to(parsed_base.path):
275+
return False
276+
277+
return parsed_target.scheme in {"http", "https"} and parsed_target.netloc == parsed_base.netloc

tests/api_fastapi/core_api/routes/public/test_login.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import pytest
2222

23+
from tests_common.test_utils.config import conf_vars
24+
2325
AUTH_MANAGER_LOGIN_URL = "http://some_login_url"
2426

2527
pytestmark = pytest.mark.db_test
@@ -39,11 +41,11 @@ class TestGetLogin(TestLoginEndpoint):
3941
[
4042
{},
4143
{"next": None},
42-
{"next": "http://localhost:28080"},
43-
{"next": "http://localhost:28080", "other_param": "something_else"},
44+
{"next": "http://localhost:8080"},
45+
{"next": "http://localhost:8080", "other_param": "something_else"},
4446
],
4547
)
46-
def test_should_respond_308(self, test_client, params):
48+
def test_should_respond_307(self, test_client, params):
4749
response = test_client.get("/public/login", follow_redirects=False, params=params)
4850

4951
assert response.status_code == 307
@@ -52,3 +54,16 @@ def test_should_respond_308(self, test_client, params):
5254
if params.get("next")
5355
else AUTH_MANAGER_LOGIN_URL
5456
)
57+
58+
@pytest.mark.parametrize(
59+
"params",
60+
[
61+
{"next": "http://fake_domain.com:8080"},
62+
{"next": "http://localhost:8080/../../up"},
63+
],
64+
)
65+
@conf_vars({("api", "base_url"): "http://localhost:8080/prefix"})
66+
def test_should_respond_400(self, test_client, params):
67+
response = test_client.get("/public/login", follow_redirects=False, params=params)
68+
69+
assert response.status_code == 400

tests/api_fastapi/core_api/test_security.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from airflow.api_fastapi.app import create_app
2626
from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity
2727
from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
28-
from airflow.api_fastapi.core_api.security import get_user, requires_access_dag
28+
from airflow.api_fastapi.core_api.security import get_user, is_safe_url, requires_access_dag
2929

3030
from tests_common.test_utils.config import conf_vars
3131

@@ -110,3 +110,28 @@ def test_requires_access_dag_unauthorized(self, mock_get_auth_manager):
110110
requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock())
111111

112112
auth_manager.is_authorized_dag.assert_called_once()
113+
114+
@pytest.mark.parametrize(
115+
"url, expected_is_safe",
116+
[
117+
("https://server_base_url.com/prefix/some_page?with_param=3", True),
118+
("https://server_base_url.com/prefix/", True),
119+
("https://server_base_url.com/prefix", True),
120+
("/prefix/some_other", True),
121+
("prefix/some_other", True),
122+
# Relative path, will go up one level escaping the prefix folder
123+
("some_other", False),
124+
("./some_other", False),
125+
# wrong scheme
126+
("javascript://server_base_url.com/prefix/some_page?with_param=3", False),
127+
# wrong netloc
128+
("https://some_netlock.com/prefix/some_page?with_param=3", False),
129+
# Absolute path escaping the prefix folder
130+
("/some_other_page/", False),
131+
# traversal, escaping the `prefix` folder
132+
("/../../../../some_page?with_param=3", False),
133+
],
134+
)
135+
@conf_vars({("api", "base_url"): "https://server_base_url.com/prefix"})
136+
def test_is_safe_url(self, url, expected_is_safe):
137+
assert is_safe_url(url) == expected_is_safe

0 commit comments

Comments
 (0)