Skip to content

Commit e9e6ac2

Browse files
authored
Handle invalid execution API urls gracefully in supervisor (#53082)
1 parent 4231a87 commit e9e6ac2

File tree

3 files changed

+96
-2
lines changed

3 files changed

+96
-2
lines changed

task-sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
TextIO,
4343
cast,
4444
)
45+
from urllib.parse import urlparse
4546
from uuid import UUID
4647

4748
import attrs
@@ -1650,12 +1651,42 @@ def supervise(
16501651
:param subprocess_logs_to_stdout: Should task logs also be sent to stdout via the main logger.
16511652
:param client: Optional preconfigured client for communication with the server (Mostly for tests).
16521653
:return: Exit code of the process.
1654+
:raises ValueError: If server URL is empty or invalid.
16531655
"""
16541656
# One or the other
16551657
from airflow.sdk.execution_time.secrets_masker import reset_secrets_masker
16561658

1657-
if not client and ((not server) ^ dry_run):
1658-
raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
1659+
if not client:
1660+
if dry_run and server:
1661+
raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
1662+
1663+
if not dry_run:
1664+
if not server:
1665+
raise ValueError(
1666+
"Invalid execution API server URL. Please ensure that a valid URL is configured."
1667+
)
1668+
1669+
try:
1670+
parsed_url = urlparse(server)
1671+
except Exception as e:
1672+
raise ValueError(
1673+
f"Invalid execution API server URL '{server}': {e}. "
1674+
"Please ensure that a valid URL is configured."
1675+
) from e
1676+
1677+
if parsed_url.scheme not in ("http", "https"):
1678+
raise ValueError(
1679+
f"Invalid execution API server URL '{server}': "
1680+
"URL must use http:// or https:// scheme. "
1681+
"Please ensure that a valid URL is configured."
1682+
)
1683+
1684+
if not parsed_url.netloc:
1685+
raise ValueError(
1686+
f"Invalid execution API server URL '{server}': "
1687+
"URL must include a valid host. "
1688+
"Please ensure that a valid URL is configured."
1689+
)
16591690

16601691
if not dag_rel_path:
16611692
raise ValueError("dag_path is required")

task-sdk/tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import os
2121
from pathlib import Path
2222
from typing import TYPE_CHECKING, Any, NoReturn, Protocol
23+
from unittest.mock import patch
2324

2425
import pytest
2526

@@ -271,3 +272,12 @@ def _make_context_dict(
271272
return context.model_dump(exclude_unset=True, mode="json")
272273

273274
return _make_context_dict
275+
276+
277+
@pytest.fixture
278+
def patched_secrets_masker():
279+
from airflow.sdk.execution_time.secrets_masker import SecretsMasker
280+
281+
secrets_masker = SecretsMasker()
282+
with patch("airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=secrets_masker):
283+
yield secrets_masker

task-sdk/tests/task_sdk/execution_time/test_supervisor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import socket
2828
import sys
2929
import time
30+
from contextlib import nullcontext
3031
from operator import attrgetter
3132
from random import randint
3233
from time import sleep
@@ -149,6 +150,58 @@ def client_with_ti_start(make_ti_context):
149150
return client
150151

151152

153+
@pytest.mark.usefixtures("disable_capturing")
154+
class TestSupervisor:
155+
@pytest.mark.parametrize(
156+
"server, dry_run, expectation",
157+
[
158+
("/execution/", False, pytest.raises(ValueError, match="Invalid execution API server URL")),
159+
("", False, pytest.raises(ValueError, match="Invalid execution API server URL")),
160+
("http://localhost:8080", True, pytest.raises(ValueError, match="Can only specify one of")),
161+
(None, True, nullcontext()),
162+
("http://localhost:8080/execution/", False, nullcontext()),
163+
("https://localhost:8080/execution/", False, nullcontext()),
164+
],
165+
)
166+
def test_supervise(
167+
self,
168+
patched_secrets_masker,
169+
server,
170+
dry_run,
171+
expectation,
172+
test_dags_dir,
173+
client_with_ti_start,
174+
):
175+
"""
176+
Test that the supervisor validates server URL and dry_run parameter combinations correctly.
177+
"""
178+
ti = TaskInstance(
179+
id=uuid7(),
180+
task_id="async",
181+
dag_id="super_basic_deferred_run",
182+
run_id="d",
183+
try_number=1,
184+
dag_version_id=uuid7(),
185+
)
186+
187+
bundle_info = BundleInfo(name="my-bundle", version=None)
188+
189+
kw = {
190+
"ti": ti,
191+
"dag_rel_path": "super_basic_deferred_run.py",
192+
"token": "",
193+
"bundle_info": bundle_info,
194+
"dry_run": dry_run,
195+
"server": server,
196+
}
197+
if isinstance(expectation, nullcontext):
198+
kw["client"] = client_with_ti_start
199+
200+
with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)):
201+
with expectation:
202+
supervise(**kw)
203+
204+
152205
@pytest.mark.usefixtures("disable_capturing")
153206
class TestWatchedSubprocess:
154207
@pytest.fixture(autouse=True)

0 commit comments

Comments
 (0)