Skip to content

Commit d6d4e74

Browse files
amoghrajeshnailo2c
authored andcommitted
Prioritize secrets backend over DB for retrieving connections (apache#47593)
1 parent a496f10 commit d6d4e74

File tree

3 files changed

+70
-13
lines changed

3 files changed

+70
-13
lines changed

airflow/hooks/base.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from __future__ import annotations
2121

2222
import logging
23-
import sys
2423
from typing import TYPE_CHECKING, Any, Protocol
2524

2625
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -60,18 +59,6 @@ def get_connection(cls, conn_id: str) -> Connection:
6059
:param conn_id: connection id
6160
:return: connection
6261
"""
63-
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
64-
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
65-
# back-compat layer
66-
67-
# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
68-
# and should use the Task SDK API server path
69-
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
70-
# TODO: AIP 72: Add deprecation here once we move this module to task sdk.
71-
from airflow.sdk import Connection as TaskSDKConnection
72-
73-
return TaskSDKConnection.get(conn_id=conn_id)
74-
7562
from airflow.models.connection import Connection
7663

7764
conn = Connection.get_connection_from_secrets(conn_id)

airflow/models/connection.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import json
2121
import logging
2222
import re
23+
import sys
2324
from contextlib import suppress
2425
from json import JSONDecodeError
2526
from typing import Any
@@ -34,6 +35,7 @@
3435
from airflow.models.crypto import get_fernet
3536
from airflow.sdk.execution_time.secrets_masker import mask_secret
3637
from airflow.secrets.cache import SecretCache
38+
from airflow.secrets.metastore import MetastoreBackend
3739
from airflow.utils.helpers import prune_dict
3840
from airflow.utils.log.logging_mixin import LoggingMixin
3941
from airflow.utils.module_loading import import_string
@@ -446,6 +448,8 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:
446448
"""
447449
Get connection by conn_id.
448450
451+
If `MetastoreBackend` is getting used in the execution context, use Task SDK API.
452+
449453
:param conn_id: connection id
450454
:return: connection
451455
"""
@@ -459,6 +463,18 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:
459463

460464
# iterate over backends if not in cache (or expired)
461465
for secrets_backend in ensure_secrets_loaded():
466+
if isinstance(secrets_backend, MetastoreBackend):
467+
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
468+
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
469+
# back-compat layer
470+
471+
# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
472+
# and should use the Task SDK API server path
473+
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
474+
# TODO: AIP 72: Add deprecation here once we move this module to task sdk.
475+
from airflow.sdk import Connection as TaskSDKConnection
476+
477+
return TaskSDKConnection.get(conn_id=conn_id)
462478
try:
463479
conn = secrets_backend.get_connection(conn_id=conn_id)
464480
if conn:

tests/hooks/test_base.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,22 @@
1717
# under the License.
1818
from __future__ import annotations
1919

20+
from unittest import mock
21+
22+
import pytest
23+
2024
from airflow.hooks.base import BaseHook
25+
from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection
26+
27+
from tests_common.test_utils.config import conf_vars
28+
29+
30+
@pytest.fixture
31+
def mock_supervisor_comms():
32+
with mock.patch(
33+
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
34+
) as supervisor_comms:
35+
yield supervisor_comms
2136

2237

2338
class TestBaseHook:
@@ -32,3 +47,42 @@ def test_custom_logger_name_is_correctly_set(self):
3247
def test_empty_string_as_logger_name(self):
3348
hook = BaseHook(logger_name="")
3449
assert hook.log.name == "airflow.task.hooks"
50+
51+
def test_get_connection(self, mock_supervisor_comms):
52+
conn = ConnectionResult(
53+
conn_id="test_conn",
54+
conn_type="mysql",
55+
host="mysql",
56+
schema="airflow",
57+
login="root",
58+
password="password",
59+
port=1234,
60+
extra='{"extra_key": "extra_value"}',
61+
)
62+
63+
mock_supervisor_comms.get_message.return_value = conn
64+
65+
hook = BaseHook(logger_name="")
66+
hook.get_connection(conn_id="test_conn")
67+
68+
mock_supervisor_comms.send_request.assert_called_once_with(
69+
msg=GetConnection(conn_id="test_conn"), log=mock.ANY
70+
)
71+
72+
def test_get_connection_secrets_backend_configured(self, mock_supervisor_comms, tmp_path):
73+
path = tmp_path / "conn.env"
74+
path.write_text("CONN_A=mysql://host_a")
75+
76+
with conf_vars(
77+
{
78+
("secrets", "backend"): "airflow.secrets.local_filesystem.LocalFilesystemBackend",
79+
("secrets", "backend_kwargs"): f'{{"connections_file_path": "{path}"}}',
80+
}
81+
):
82+
hook = BaseHook(logger_name="")
83+
retrieved_conn = hook.get_connection(conn_id="CONN_A")
84+
85+
assert retrieved_conn.conn_id == "CONN_A"
86+
87+
mock_supervisor_comms.send_request.assert_not_called()
88+
mock_supervisor_comms.get_message.assert_not_called()

0 commit comments

Comments
 (0)