Skip to content

Commit 6c6a4a6

Browse files
authored
Add a backup implementation in AWS MwaaHook for calling the MWAA API (apache#47035)
The existing implementation doesn't work when the user doesn't have `airflow:InvokeRestApi` permission in their IAM policy or when they make more than 10 transactions per second. This implementation mitigates those issues by using a session token approach. However, my existing implementation is still used by default because it is simpler. Some context here: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html
1 parent b5038ef commit 6c6a4a6

File tree

3 files changed

+238
-71
lines changed

3 files changed

+238
-71
lines changed

docs/spelling_wordlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,6 +1863,7 @@ urls
18631863
useHCatalog
18641864
useLegacySQL
18651865
useQueryCache
1866+
userguide
18661867
userId
18671868
userpass
18681869
usr

providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py

Lines changed: 79 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from __future__ import annotations
2020

21+
import requests
2122
from botocore.exceptions import ClientError
2223

2324
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):
2930
3031
Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
3132
33+
If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
34+
that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
35+
make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
36+
setting `generate_local_token` to True. Learn more here:
37+
https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
38+
3239
Additional arguments (such as ``aws_conn_id``) may be specified and
3340
are passed down to the underlying AwsBaseHook.
3441
@@ -47,6 +54,7 @@ def invoke_rest_api(
4754
method: str,
4855
body: dict | None = None,
4956
query_params: dict | None = None,
57+
generate_local_token: bool = False,
5058
) -> dict:
5159
"""
5260
Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ def invoke_rest_api(
5664
5765
:param env_name: name of the MWAA environment
5866
:param path: Apache Airflow REST API endpoint path to be called
59-
:param method: HTTP method used for making Airflow REST API calls
67+
:param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
6068
:param body: Request body for the Apache Airflow REST API call
6169
:param query_params: Query parameters to be included in the Apache Airflow REST API call
70+
:param generate_local_token: If True, only the local web token method is used without trying boto's
71+
`invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
72+
boto's `invoke_rest_api`
6273
"""
63-
body = body or {}
74+
# Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
75+
body = {k: v for k, v in body.items() if v is not None} if body else {}
76+
query_params = query_params or {}
6477
api_kwargs = {
6578
"Name": env_name,
6679
"Path": path,
6780
"Method": method,
68-
# Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
69-
"Body": {k: v for k, v in body.items() if v is not None},
70-
"QueryParameters": query_params if query_params else {},
81+
"Body": body,
82+
"QueryParameters": query_params,
7183
}
84+
85+
if generate_local_token:
86+
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
87+
7288
try:
73-
result = self.conn.invoke_rest_api(**api_kwargs)
89+
response = self.conn.invoke_rest_api(**api_kwargs)
7490
# ResponseMetadata is removed because it contains data that is either very unlikely to be useful
7591
# in XComs and logs, or redundant given the data already included in the response
76-
result.pop("ResponseMetadata", None)
77-
return result
92+
response.pop("ResponseMetadata", None)
93+
return response
94+
7895
except ClientError as e:
79-
to_log = e.response
80-
# ResponseMetadata and Error are removed because they contain data that is either very unlikely to
81-
# be useful in XComs and logs, or redundant given the data already included in the response
82-
to_log.pop("ResponseMetadata", None)
83-
to_log.pop("Error", None)
84-
self.log.error(to_log)
85-
raise e
96+
if (
97+
e.response["Error"]["Code"] == "AccessDeniedException"
98+
and "Airflow role" in e.response["Error"]["Message"]
99+
):
100+
self.log.info(
101+
"Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
102+
)
103+
return self._invoke_rest_api_using_local_session_token(**api_kwargs)
104+
else:
105+
to_log = e.response
106+
# ResponseMetadata is removed because it contains data that is either very unlikely to be
107+
# useful in XComs and logs, or redundant given the data already included in the response
108+
to_log.pop("ResponseMetadata", None)
109+
self.log.error(to_log)
110+
raise
111+
112+
def _invoke_rest_api_using_local_session_token(
113+
self,
114+
**api_kwargs,
115+
) -> dict:
116+
try:
117+
session, hostname = self._get_session_conn(api_kwargs["Name"])
118+
119+
response = session.request(
120+
method=api_kwargs["Method"],
121+
url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
122+
params=api_kwargs["QueryParameters"],
123+
json=api_kwargs["Body"],
124+
timeout=10,
125+
)
126+
response.raise_for_status()
127+
128+
except requests.HTTPError as e:
129+
self.log.error(e.response.json())
130+
raise
131+
132+
return {
133+
"RestApiStatusCode": response.status_code,
134+
"RestApiResponse": response.json(),
135+
}
136+
137+
# Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
138+
def _get_session_conn(self, env_name: str) -> tuple:
139+
create_token_response = self.conn.create_web_login_token(Name=env_name)
140+
web_server_hostname = create_token_response["WebServerHostname"]
141+
web_token = create_token_response["WebToken"]
142+
143+
login_url = f"https://{web_server_hostname}/aws_mwaa/login"
144+
login_payload = {"token": web_token}
145+
session = requests.Session()
146+
login_response = session.post(login_url, data=login_payload, timeout=10)
147+
login_response.raise_for_status()
148+
149+
return session, web_server_hostname

providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py

Lines changed: 158 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from unittest import mock
2020

2121
import pytest
22+
import requests
2223
from botocore.exceptions import ClientError
2324
from moto import mock_aws
2425

@@ -27,16 +28,161 @@
2728
ENV_NAME = "test_env"
2829
PATH = "/dags/test_dag/dagRuns"
2930
METHOD = "POST"
31+
BODY: dict = {"conf": {}}
3032
QUERY_PARAMS = {"limit": 30}
33+
HOSTNAME = "example.com"
3134

3235

3336
class TestMwaaHook:
37+
@pytest.fixture
38+
def mock_conn(self):
39+
with mock.patch.object(MwaaHook, "conn") as m:
40+
yield m
41+
3442
def setup_method(self):
3543
self.hook = MwaaHook()
3644

37-
# these example responses are included here instead of as a constant because the hook will mutate
38-
# responses causing subsequent tests to fail
39-
self.example_responses = {
45+
def test_init(self):
46+
assert self.hook.client_type == "mwaa"
47+
48+
@mock_aws
49+
def test_get_conn(self):
50+
assert self.hook.conn is not None
51+
52+
@pytest.mark.parametrize(
53+
"body",
54+
[
55+
pytest.param(None, id="no_body"),
56+
pytest.param(BODY, id="non_empty_body"),
57+
],
58+
)
59+
def test_invoke_rest_api_success(self, body, mock_conn, example_responses):
60+
boto_invoke_mock = mock.MagicMock(return_value=example_responses["success"])
61+
mock_conn.invoke_rest_api = boto_invoke_mock
62+
63+
retval = self.hook.invoke_rest_api(
64+
env_name=ENV_NAME, path=PATH, method=METHOD, body=body, query_params=QUERY_PARAMS
65+
)
66+
kwargs_to_assert = {
67+
"Name": ENV_NAME,
68+
"Path": PATH,
69+
"Method": METHOD,
70+
"Body": body if body else {},
71+
"QueryParameters": QUERY_PARAMS,
72+
}
73+
boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
74+
mock_conn.create_web_login_token.assert_not_called()
75+
assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"}
76+
77+
def test_invoke_rest_api_failure(self, mock_conn, example_responses):
78+
error = ClientError(error_response=example_responses["failure"], operation_name="invoke_rest_api")
79+
mock_conn.invoke_rest_api = mock.MagicMock(side_effect=error)
80+
mock_error_log = mock.MagicMock()
81+
self.hook.log.error = mock_error_log
82+
83+
with pytest.raises(ClientError) as caught_error:
84+
self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD)
85+
86+
assert caught_error.value == error
87+
mock_conn.create_web_login_token.assert_not_called()
88+
expected_log = {k: v for k, v in example_responses["failure"].items() if k != "ResponseMetadata"}
89+
mock_error_log.assert_called_once_with(expected_log)
90+
91+
@pytest.mark.parametrize("generate_local_token", [pytest.param(True), pytest.param(False)])
92+
@mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
93+
def test_invoke_rest_api_local_token_parameter(
94+
self, mock_create_session, generate_local_token, mock_conn
95+
):
96+
self.hook.invoke_rest_api(
97+
env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=generate_local_token
98+
)
99+
if generate_local_token:
100+
mock_conn.invoke_rest_api.assert_not_called()
101+
mock_conn.create_web_login_token.assert_called_once()
102+
mock_create_session.assert_called_once()
103+
mock_create_session.return_value.request.assert_called_once()
104+
else:
105+
mock_conn.invoke_rest_api.assert_called_once()
106+
107+
@mock.patch.object(MwaaHook, "_get_session_conn")
108+
def test_invoke_rest_api_fallback_success_when_iam_fails(
109+
self, mock_get_session_conn, mock_conn, example_responses
110+
):
111+
boto_invoke_error = ClientError(
112+
error_response=example_responses["missingIamRole"], operation_name="invoke_rest_api"
113+
)
114+
mock_conn.invoke_rest_api = mock.MagicMock(side_effect=boto_invoke_error)
115+
116+
kwargs_to_assert = {
117+
"method": METHOD,
118+
"url": f"https://{HOSTNAME}/api/v1{PATH}",
119+
"params": QUERY_PARAMS,
120+
"json": BODY,
121+
"timeout": 10,
122+
}
123+
124+
mock_response = mock.MagicMock()
125+
mock_response.status_code = example_responses["success"]["RestApiStatusCode"]
126+
mock_response.json.return_value = example_responses["success"]["RestApiResponse"]
127+
mock_session = mock.MagicMock()
128+
mock_session.request.return_value = mock_response
129+
130+
mock_get_session_conn.return_value = (mock_session, HOSTNAME)
131+
132+
retval = self.hook.invoke_rest_api(
133+
env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY, query_params=QUERY_PARAMS
134+
)
135+
136+
mock_session.request.assert_called_once_with(**kwargs_to_assert)
137+
mock_response.raise_for_status.assert_called_once()
138+
assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"}
139+
140+
@mock.patch.object(MwaaHook, "_get_session_conn")
141+
def test_invoke_rest_api_using_local_session_token_failure(
142+
self, mock_get_session_conn, example_responses
143+
):
144+
mock_response = mock.MagicMock()
145+
mock_response.json.return_value = example_responses["failure"]["RestApiResponse"]
146+
error = requests.HTTPError(response=mock_response)
147+
mock_response.raise_for_status.side_effect = error
148+
149+
mock_session = mock.MagicMock()
150+
mock_session.request.return_value = mock_response
151+
152+
mock_get_session_conn.return_value = (mock_session, HOSTNAME)
153+
154+
mock_error_log = mock.MagicMock()
155+
self.hook.log.error = mock_error_log
156+
157+
with pytest.raises(requests.HTTPError) as caught_error:
158+
self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=True)
159+
160+
assert caught_error.value == error
161+
mock_error_log.assert_called_once_with(example_responses["failure"]["RestApiResponse"])
162+
163+
@mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
164+
def test_get_session_conn(self, mock_create_session, mock_conn):
165+
token = "token"
166+
mock_conn.create_web_login_token.return_value = {"WebServerHostname": HOSTNAME, "WebToken": token}
167+
login_url = f"https://{HOSTNAME}/aws_mwaa/login"
168+
login_payload = {"token": token}
169+
170+
mock_session = mock.MagicMock()
171+
mock_create_session.return_value = mock_session
172+
173+
retval = self.hook._get_session_conn(env_name=ENV_NAME)
174+
175+
mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME)
176+
mock_create_session.assert_called_once_with()
177+
mock_session.post.assert_called_once_with(login_url, data=login_payload, timeout=10)
178+
mock_session.post.return_value.raise_for_status.assert_called_once()
179+
180+
assert retval == (mock_session, HOSTNAME)
181+
182+
@pytest.fixture
183+
def example_responses(self):
184+
"""Fixture for test responses to avoid mutation between tests."""
185+
return {
40186
"success": {
41187
"ResponseMetadata": {
42188
"RequestId": "some ID",
@@ -73,57 +219,13 @@ def setup_method(self):
73219
"type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound",
74220
},
75221
},
222+
"missingIamRole": {
223+
"Error": {"Message": "No Airflow role granted in IAM.", "Code": "AccessDeniedException"},
224+
"ResponseMetadata": {
225+
"RequestId": "some ID",
226+
"HTTPStatusCode": 403,
227+
"HTTPHeaders": {"header1": "value1"},
228+
"RetryAttempts": 0,
229+
},
230+
},
76231
}
77-
78-
def test_init(self):
79-
assert self.hook.client_type == "mwaa"
80-
81-
@mock_aws
82-
def test_get_conn(self):
83-
assert self.hook.conn is not None
84-
85-
@pytest.mark.parametrize(
86-
"body",
87-
[
88-
pytest.param(None, id="no_body"),
89-
pytest.param({"conf": {}}, id="non_empty_body"),
90-
],
91-
)
92-
@mock.patch.object(MwaaHook, "conn")
93-
def test_invoke_rest_api_success(self, mock_conn, body) -> None:
94-
boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"])
95-
mock_conn.invoke_rest_api = boto_invoke_mock
96-
97-
retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS)
98-
kwargs_to_assert = {
99-
"Name": ENV_NAME,
100-
"Path": PATH,
101-
"Method": METHOD,
102-
"Body": body if body else {},
103-
"QueryParameters": QUERY_PARAMS,
104-
}
105-
boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
106-
assert retval == {
107-
k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata"
108-
}
109-
110-
@mock.patch.object(MwaaHook, "conn")
111-
def test_invoke_rest_api_failure(self, mock_conn) -> None:
112-
error = ClientError(
113-
error_response=self.example_responses["failure"], operation_name="invoke_rest_api"
114-
)
115-
boto_invoke_mock = mock.MagicMock(side_effect=error)
116-
mock_conn.invoke_rest_api = boto_invoke_mock
117-
mock_log = mock.MagicMock()
118-
self.hook.log.error = mock_log
119-
120-
with pytest.raises(ClientError) as caught_error:
121-
self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD)
122-
123-
assert caught_error.value == error
124-
expected_log = {
125-
k: v
126-
for k, v in self.example_responses["failure"].items()
127-
if k != "ResponseMetadata" and k != "Error"
128-
}
129-
mock_log.assert_called_once_with(expected_log)

0 commit comments

Comments
 (0)