Skip to content

Commit 1ba8c1d

Browse files
authored
feat: async slack notifier (apache#56685)
1 parent 51016db commit 1ba8c1d

File tree

10 files changed

+204
-23
lines changed

10 files changed

+204
-23
lines changed

providers/slack/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ dependencies = [
6161
"apache-airflow-providers-common-compat>=1.6.1",
6262
"apache-airflow-providers-common-sql>=1.27.0",
6363
"slack-sdk>=3.36.0",
64+
"asgiref>=2.3.0",
6465
]
6566

6667
[dependency-groups]

providers/slack/src/airflow/providers/slack/hooks/slack.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
# KIND, either express or implied. See the License for the
1616
# specific language governing permissions and limitations
1717
# under the License.
18+
"""
19+
Hook for Slack.
20+
21+
.. spelling:word-list::
22+
23+
AsyncSlackResponse
24+
"""
25+
1826
from __future__ import annotations
1927

2028
import json
@@ -27,17 +35,21 @@
2735

2836
from slack_sdk import WebClient
2937
from slack_sdk.errors import SlackApiError
38+
from slack_sdk.web.async_client import AsyncWebClient
3039
from typing_extensions import NotRequired
3140

3241
from airflow.exceptions import AirflowException, AirflowNotFoundException
33-
from airflow.providers.slack.utils import ConnectionExtraConfig
42+
from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection
3443
from airflow.providers.slack.version_compat import BaseHook
3544
from airflow.utils.helpers import exactly_one
3645

3746
if TYPE_CHECKING:
3847
from slack_sdk.http_retry import RetryHandler
48+
from slack_sdk.web.async_client import AsyncSlackResponse
3949
from slack_sdk.web.slack_response import SlackResponse
4050

51+
from airflow.providers.slack.version_compat import Connection
52+
4153

4254
class FileUploadTypeDef(TypedDict):
4355
"""
@@ -140,15 +152,20 @@ def __init__(
140152
@cached_property
141153
def client(self) -> WebClient:
142154
"""Get the underlying slack_sdk.WebClient (cached)."""
143-
return WebClient(**self._get_conn_params())
155+
conn = self.get_connection(self.slack_conn_id)
156+
return WebClient(**self._get_conn_params(conn=conn))
157+
158+
async def get_async_client(self) -> AsyncWebClient:
159+
"""Get the underlying `slack_sdk.web.async_client.AsyncWebClient`."""
160+
conn = await get_async_connection(self.slack_conn_id)
161+
return AsyncWebClient(**self._get_conn_params(conn))
144162

145163
def get_conn(self) -> WebClient:
146164
"""Get the underlying slack_sdk.WebClient (cached)."""
147165
return self.client
148166

149-
def _get_conn_params(self) -> dict[str, Any]:
167+
def _get_conn_params(self, conn: Connection) -> dict[str, Any]:
150168
"""Fetch connection params as a dict and merge it with hook parameters."""
151-
conn = self.get_connection(self.slack_conn_id)
152169
if not conn.password:
153170
raise AirflowNotFoundException(
154171
f"Connection ID {self.slack_conn_id!r} does not contain password (Slack API Token)."
@@ -186,6 +203,24 @@ def call(self, api_method: str, **kwargs) -> SlackResponse:
186203
"""
187204
return self.client.api_call(api_method, **kwargs)
188205

206+
async def async_call(self, api_method: str, **kwargs) -> AsyncSlackResponse:
207+
"""
208+
Call Slack WebClient `AsyncWebClient.api_call` with given arguments.
209+
210+
:param api_method: The target Slack API method. e.g. 'chat.postMessage'. Required.
211+
:param http_verb: HTTP Verb. Optional (defaults to 'POST')
212+
:param files: Files to multipart upload. e.g. {imageORfile: file_objectORfile_path}
213+
:param data: The body to attach to the request. If a dictionary is provided,
214+
form-encoding will take place. Optional.
215+
:param params: The URL parameters to append to the URL. Optional.
216+
:param json: JSON for the body to attach to the request. Optional.
217+
:return: The server's response to an HTTP request. Data from the response can be
218+
accessed like a dict. If the response included 'next_cursor' it can be
219+
iterated on to execute subsequent requests.
220+
"""
221+
client = await self.get_async_client()
222+
return await client.api_call(api_method, **kwargs)
223+
189224
def send_file_v2(
190225
self,
191226
*,

providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from slack_sdk.webhook.async_client import AsyncWebhookClient
2828

2929
from airflow.exceptions import AirflowException, AirflowNotFoundException
30-
from airflow.providers.slack.utils import ConnectionExtraConfig
30+
from airflow.providers.slack.utils import ConnectionExtraConfig, get_async_connection
3131
from airflow.providers.slack.version_compat import BaseHook
3232

3333
if TYPE_CHECKING:
@@ -152,9 +152,8 @@ def client(self) -> WebhookClient:
152152
"""Get the underlying slack_sdk.webhook.WebhookClient (cached)."""
153153
return WebhookClient(**self._get_conn_params())
154154

155-
@cached_property
156-
async def async_client(self) -> AsyncWebhookClient:
157-
"""Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient` (cached)."""
155+
async def get_async_client(self) -> AsyncWebhookClient:
156+
"""Get the underlying `slack_sdk.webhook.async_client.AsyncWebhookClient`."""
158157
return AsyncWebhookClient(**await self._async_get_conn_params())
159158

160159
def get_conn(self) -> WebhookClient:
@@ -168,7 +167,7 @@ def _get_conn_params(self) -> dict[str, Any]:
168167

169168
async def _async_get_conn_params(self) -> dict[str, Any]:
170169
"""Fetch connection params as a dict and merge it with hook parameters (async)."""
171-
conn = await self.aget_connection(self.slack_webhook_conn_id)
170+
conn = await get_async_connection(self.slack_webhook_conn_id)
172171
return self._build_conn_params(conn)
173172

174173
def _build_conn_params(self, conn) -> dict[str, Any]:
@@ -251,7 +250,7 @@ async def async_send_dict(self, body: dict[str, Any] | str, *, headers: dict[str
251250
:param headers: Request headers for this request.
252251
"""
253252
body = self._process_body(body)
254-
async_client = await self.async_client
253+
async_client = await self.get_async_client()
255254
return await async_client.send_dict(body, headers=headers)
256255

257256
def send(

providers/slack/src/airflow/providers/slack/notifications/slack.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from airflow.providers.common.compat.notifier import BaseNotifier
2626
from airflow.providers.slack.hooks.slack import SlackHook
27+
from airflow.providers.slack.version_compat import AIRFLOW_V_3_1_PLUS
2728

2829
if TYPE_CHECKING:
2930
from slack_sdk.http_retry import RetryHandler
@@ -71,8 +72,13 @@ def __init__(
7172
retry_handlers: list[RetryHandler] | None = None,
7273
unfurl_links: bool = True,
7374
unfurl_media: bool = True,
75+
**kwargs,
7476
):
75-
super().__init__()
77+
if AIRFLOW_V_3_1_PLUS:
78+
# Support for passing context was added in 3.1.0
79+
super().__init__(**kwargs)
80+
else:
81+
super().__init__()
7682
self.slack_conn_id = slack_conn_id
7783
self.text = text
7884
self.channel = channel
@@ -112,5 +118,19 @@ def notify(self, context):
112118
}
113119
self.hook.call("chat.postMessage", json=api_call_params)
114120

121+
async def async_notify(self, context):
122+
"""Send a message to a Slack Channel (async)."""
123+
api_call_params = {
124+
"channel": self.channel,
125+
"username": self.username,
126+
"text": self.text,
127+
"icon_url": self.icon_url,
128+
"attachments": json.dumps(self.attachments),
129+
"blocks": json.dumps(self.blocks),
130+
"unfurl_links": self.unfurl_links,
131+
"unfurl_media": self.unfurl_media,
132+
}
133+
await self.hook.async_call("chat.postMessage", json=api_call_params)
134+
115135

116136
send_slack_notification = SlackNotifier

providers/slack/src/airflow/providers/slack/utils/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from collections.abc import Sequence
2121
from typing import Any
2222

23+
from asgiref.sync import sync_to_async
24+
25+
from airflow.providers.slack.version_compat import BaseHook, Connection
2326
from airflow.utils.types import NOTSET
2427

2528

@@ -120,3 +123,15 @@ def parse_filename(
120123
if fallback:
121124
return fallback, None
122125
raise ex from None
126+
127+
128+
async def get_async_connection(conn_id: str) -> Connection:
129+
"""
130+
Get an asynchronous Airflow connection that is backwards compatible.
131+
132+
:param conn_id: The provided connection ID.
133+
:returns: Connection
134+
"""
135+
if hasattr(BaseHook, "aget_connection"):
136+
return await BaseHook.aget_connection(conn_id=conn_id)
137+
return await sync_to_async(BaseHook.get_connection)(conn_id=conn_id)

providers/slack/src/airflow/providers/slack/version_compat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
3636
AIRFLOW_V_3_1_PLUS: bool = get_base_airflow_version_tuple() >= (3, 1, 0)
3737

3838
if AIRFLOW_V_3_0_PLUS:
39-
from airflow.sdk import BaseOperator
39+
from airflow.sdk import BaseOperator, Connection
4040
else:
41-
from airflow.models import BaseOperator
41+
from airflow.models import BaseOperator, Connection # type: ignore[assignment]
4242

4343
if AIRFLOW_V_3_1_PLUS:
4444
from airflow.sdk import BaseHook
@@ -50,4 +50,5 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
5050
"AIRFLOW_V_3_1_PLUS",
5151
"BaseHook",
5252
"BaseOperator",
53+
"Connection",
5354
]

providers/slack/tests/unit/slack/hooks/test_slack.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,24 @@ def make_429():
105105
def test_get_token_from_connection(self, conn_id):
106106
"""Test retrieve token from Slack API Connection ID."""
107107
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
108-
assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN
108+
conn = hook.get_connection(hook.slack_conn_id)
109+
assert hook._get_conn_params(conn)["token"] == MOCK_SLACK_API_TOKEN
109110

110111
def test_resolve_token(self):
111112
"""Test that we only use token from Slack API Connection ID."""
112113
with pytest.warns(UserWarning, match="Provide `token` as part of .* parameters is disallowed"):
113114
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID, token="foo-bar")
115+
conn = hook.get_connection(hook.slack_conn_id)
114116
assert "token" not in hook.extra_client_args
115-
assert hook._get_conn_params()["token"] == MOCK_SLACK_API_TOKEN
117+
assert hook._get_conn_params(conn)["token"] == MOCK_SLACK_API_TOKEN
116118

117119
def test_empty_password(self):
118120
"""Test password field defined in the connection."""
119121
hook = SlackHook(slack_conn_id="empty_slack_connection")
122+
conn = hook.get_connection(hook.slack_conn_id)
120123
error_message = r"Connection ID '.*' does not contain password \(Slack API Token\)\."
121124
with pytest.raises(AirflowNotFoundException, match=error_message):
122-
hook._get_conn_params()
125+
hook._get_conn_params(conn)
123126

124127
@pytest.mark.parametrize(
125128
"hook_config,conn_extra,expected",
@@ -228,8 +231,9 @@ def test_client_configuration(
228231

229232
with mock.patch.dict("os.environ", values={test_conn_env: test_conn.get_uri()}):
230233
hook = SlackHook(slack_conn_id=test_conn.conn_id, **hook_config)
234+
conn = hook.get_connection(hook.slack_conn_id)
231235
expected["logger"] = hook.log
232-
conn_params = hook._get_conn_params()
236+
conn_params = hook._get_conn_params(conn)
233237
assert conn_params == expected
234238

235239
client = hook.client
@@ -319,7 +323,8 @@ def test_hook_connection_failed(self, mocked_client, response_data):
319323
def test_backcompat_prefix_works(self, uri, monkeypatch):
320324
monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", uri)
321325
hook = SlackHook(slack_conn_id="my_conn")
322-
params = hook._get_conn_params()
326+
conn = hook.get_connection(hook.slack_conn_id)
327+
params = hook._get_conn_params(conn)
323328
assert params["token"] == "abc"
324329
assert params["timeout"] == 123
325330
assert params["base_url"] == "base_url"
@@ -328,8 +333,9 @@ def test_backcompat_prefix_works(self, uri, monkeypatch):
328333
def test_backcompat_prefix_both_causes_warning(self, monkeypatch):
329334
monkeypatch.setenv("AIRFLOW_CONN_MY_CONN", "a://:abc@?extra__slack__timeout=111&timeout=222")
330335
hook = SlackHook(slack_conn_id="my_conn")
336+
conn = hook.get_connection(hook.slack_conn_id)
331337
with pytest.warns(Warning, match="Using value for `timeout`"):
332-
params = hook._get_conn_params()
338+
params = hook._get_conn_params(conn)
333339
assert params["timeout"] == 222
334340

335341
def test_empty_string_ignored_prefixed(self, monkeypatch):
@@ -340,7 +346,8 @@ def test_empty_string_ignored_prefixed(self, monkeypatch):
340346
),
341347
)
342348
hook = SlackHook(slack_conn_id="my_conn")
343-
params = hook._get_conn_params()
349+
conn = hook.get_connection(hook.slack_conn_id)
350+
params = hook._get_conn_params(conn)
344351
assert "proxy" not in params
345352
assert "base_url" not in params
346353

@@ -350,7 +357,8 @@ def test_empty_string_ignored_non_prefixed(self, monkeypatch):
350357
json.dumps({"password": "hi", "extra": {"base_url": "", "proxy": ""}}),
351358
)
352359
hook = SlackHook(slack_conn_id="my_conn")
353-
params = hook._get_conn_params()
360+
conn = hook.get_connection(hook.slack_conn_id)
361+
params = hook._get_conn_params(conn)
354362
assert "proxy" not in params
355363
assert "base_url" not in params
356364

@@ -539,3 +547,36 @@ def test_send_file_v1_to_v2_multiple_channels(self, channels, expected_calls):
539547
with mock.patch.object(SlackHook, "send_file_v2") as mocked_send_file_v2:
540548
hook.send_file_v1_to_v2(channels=channels, content="Fake")
541549
assert mocked_send_file_v2.call_count == expected_calls
550+
551+
552+
class TestSlackHookAsync:
553+
@pytest.fixture
554+
def mock_get_conn(self):
555+
with mock.patch(
556+
"airflow.providers.slack.hooks.slack.get_async_connection", new_callable=mock.AsyncMock
557+
) as m:
558+
m.return_value = Connection(
559+
conn_id=SLACK_API_DEFAULT_CONN_ID,
560+
conn_type=CONN_TYPE,
561+
password=MOCK_SLACK_API_TOKEN,
562+
)
563+
yield m
564+
565+
@pytest.mark.asyncio
566+
@mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient")
567+
async def test_get_async_client(self, mock_client, mock_get_conn):
568+
"""Test get_async_client creates AsyncWebClient with correct params."""
569+
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
570+
await hook.get_async_client()
571+
mock_get_conn.assert_called()
572+
mock_client.assert_called_once_with(token=MOCK_SLACK_API_TOKEN, logger=mock.ANY)
573+
574+
@pytest.mark.asyncio
575+
@mock.patch("airflow.providers.slack.hooks.slack.AsyncWebClient.api_call", new_callable=mock.AsyncMock)
576+
async def test_async_call(self, mock_api_call, mock_get_conn):
577+
"""Test async_call is called correctly."""
578+
hook = SlackHook(slack_conn_id=SLACK_API_DEFAULT_CONN_ID)
579+
test_api_json = {"channel": "test_channel"}
580+
await hook.async_call("chat.postMessage", json=test_api_json)
581+
mock_get_conn.assert_called()
582+
mock_api_call.assert_called_with("chat.postMessage", json=test_api_json)

providers/slack/tests/unit/slack/hooks/test_slack_webhook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ async def test_async_client(self, mock_async_get_conn_params):
555555
mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL}
556556

557557
hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID)
558-
client = await hook.async_client
558+
client = await hook.get_async_client()
559559

560560
assert isinstance(client, AsyncWebhookClient)
561561
assert client.url == TEST_WEBHOOK_URL

providers/slack/tests/unit/slack/notifications/test_slack.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,31 @@ def test_slack_notifier_unfurl_options(self, mock_slack_hook, create_dag_without
133133
"unfurl_media": False,
134134
},
135135
)
136+
137+
@pytest.mark.asyncio
138+
@mock.patch("airflow.providers.slack.notifications.slack.SlackHook")
139+
async def test_async_slack_notifier(self, mock_slack_hook):
140+
mock_slack_hook.return_value.async_call = mock.AsyncMock()
141+
142+
notifier = send_slack_notification(
143+
text="test",
144+
unfurl_links=False,
145+
unfurl_media=False,
146+
)
147+
148+
await notifier.async_notify({})
149+
150+
mock_slack_hook.return_value.async_call.assert_called_once_with(
151+
"chat.postMessage",
152+
json={
153+
"channel": "#general",
154+
"username": "Airflow",
155+
"text": "test",
156+
"icon_url": "https://raw.githubusercontent.com/apache/airflow/main/airflow-core"
157+
"/src/airflow/ui/public/pin_100.png",
158+
"attachments": "[]",
159+
"blocks": "[]",
160+
"unfurl_links": False,
161+
"unfurl_media": False,
162+
},
163+
)

0 commit comments

Comments
 (0)