Skip to content

Commit 00c2359

Browse files
authored
Add AwaitMessageTrigger for Redis PubSub (#52917)
* add redis_pub_sub trigger * add sync_to_async * add docs for Redis trigger * fix = length * fix - length * add Triggers in toctree in index.rst
1 parent 5245f15 commit 00c2359

File tree

8 files changed

+237
-0
lines changed

8 files changed

+237
-0
lines changed

providers/redis/docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
Connection types <connections>
3838
Logging <logging/index>
39+
Triggers <triggers>
3940

4041
.. toctree::
4142
:hidden:

providers/redis/docs/triggers.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
.. Licensed to the Apache Software Foundation (ASF) under one
2+
or more contributor license agreements. See the NOTICE file
3+
distributed with this work for additional information
4+
regarding copyright ownership. The ASF licenses this file
5+
to you under the Apache License, Version 2.0 (the
6+
"License"); you may not use this file except in compliance
7+
with the License. You may obtain a copy of the License at
8+
9+
.. http://www.apache.org/licenses/LICENSE-2.0
10+
11+
.. Unless required by applicable law or agreed to in writing,
12+
software distributed under the License is distributed on an
13+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
KIND, either express or implied. See the License for the
15+
specific language governing permissions and limitations
16+
under the License.
17+
18+
19+
Redis Triggers
20+
==============
21+
22+
.. _howto/triggers:AwaitMessageTrigger:
23+
24+
AwaitMessageTrigger
25+
-------------------
26+
27+
The ``AwaitMessageTrigger`` is a trigger that asynchronously waits for a message to arrive on one or more specified Redis PubSub channels.
28+
29+
For parameter definitions take a look at :class:`~airflow.providers.redis.triggers.redis_await_message.AwaitMessageTrigger`.

providers/redis/provider.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ sensors:
7373
- airflow.providers.redis.sensors.redis_key
7474
- airflow.providers.redis.sensors.redis_pub_sub
7575

76+
triggers:
77+
- integration-name: Redis
78+
python-modules:
79+
- airflow.providers.redis.triggers.redis_await_message
80+
7681
hooks:
7782
- integration-name: Redis
7883
python-modules:

providers/redis/src/airflow/providers/redis/get_provider_info.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def get_provider_info():
4949
],
5050
}
5151
],
52+
"triggers": [
53+
{
54+
"integration-name": "Redis",
55+
"python-modules": ["airflow.providers.redis.triggers.redis_await_message"],
56+
}
57+
],
5258
"hooks": [{"integration-name": "Redis", "python-modules": ["airflow.providers.redis.hooks.redis"]}],
5359
"connection-types": [
5460
{"hook-class-name": "airflow.providers.redis.hooks.redis.RedisHook", "connection-type": "redis"}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
import asyncio
20+
from typing import Any
21+
22+
from asgiref.sync import sync_to_async
23+
24+
from airflow.providers.redis.hooks.redis import RedisHook
25+
from airflow.triggers.base import BaseTrigger, TriggerEvent
26+
27+
28+
class AwaitMessageTrigger(BaseTrigger):
29+
"""
30+
A trigger that waits for a message matching specific criteria to arrive in Redis.
31+
32+
The behavior of this trigger is as follows:
33+
- poll the Redis pubsub for a message, if no message returned, sleep
34+
35+
:param channels: The channels that should be searched for messages
36+
:param redis_conn_id: The connection object to use, defaults to "redis_default"
37+
:param poll_interval: How long the trigger should sleep after reaching the end of the Redis log
38+
(seconds), defaults to 60
39+
"""
40+
41+
def __init__(
42+
self,
43+
channels: list[str] | str,
44+
redis_conn_id: str = "redis_default",
45+
poll_interval: float = 60,
46+
) -> None:
47+
self.channels = channels
48+
self.redis_conn_id = redis_conn_id
49+
self.poll_interval = poll_interval
50+
51+
def serialize(self) -> tuple[str, dict[str, Any]]:
52+
return (
53+
"airflow.providers.redis.triggers.redis_await_message.AwaitMessageTrigger",
54+
{
55+
"channels": self.channels,
56+
"redis_conn_id": self.redis_conn_id,
57+
"poll_interval": self.poll_interval,
58+
},
59+
)
60+
61+
async def run(self):
62+
hook = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub()
63+
hook.subscribe(self.channels)
64+
65+
async_get_message = sync_to_async(hook.get_message)
66+
while True:
67+
message = await async_get_message()
68+
69+
if message and message["type"] == "message":
70+
yield TriggerEvent(message)
71+
break
72+
await asyncio.sleep(self.poll_interval)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import asyncio
21+
from unittest.mock import patch
22+
23+
import pytest
24+
25+
from airflow.providers.redis.triggers.redis_await_message import AwaitMessageTrigger
26+
27+
28+
class TestAwaitMessageTrigger:
29+
def test_trigger_serialization(self):
30+
trigger = AwaitMessageTrigger(
31+
channels=["test_channel"],
32+
redis_conn_id="redis_default",
33+
poll_interval=30,
34+
)
35+
36+
assert isinstance(trigger, AwaitMessageTrigger)
37+
38+
classpath, kwargs = trigger.serialize()
39+
40+
assert classpath == "airflow.providers.redis.triggers.redis_await_message.AwaitMessageTrigger"
41+
assert kwargs == dict(
42+
channels=["test_channel"],
43+
redis_conn_id="redis_default",
44+
poll_interval=30,
45+
)
46+
47+
@patch("airflow.providers.redis.hooks.redis.RedisHook.get_conn")
48+
@pytest.mark.asyncio
49+
async def test_trigger_run_succeed(self, mock_redis_conn):
50+
trigger = AwaitMessageTrigger(
51+
channels="test",
52+
redis_conn_id="redis_default",
53+
poll_interval=0.0001,
54+
)
55+
56+
mock_redis_conn().pubsub().get_message.return_value = {
57+
"type": "message",
58+
"channel": "test",
59+
"data": "d1",
60+
}
61+
62+
trigger_gen = trigger.run()
63+
task = asyncio.create_task(trigger_gen.__anext__())
64+
event = await task
65+
assert task.done() is True
66+
assert event.payload["data"] == "d1"
67+
assert event.payload["channel"] == "test"
68+
asyncio.get_event_loop().stop()
69+
70+
@patch("airflow.providers.redis.hooks.redis.RedisHook.get_conn")
71+
@pytest.mark.asyncio
72+
async def test_trigger_run_fail(self, mock_redis_conn):
73+
trigger = AwaitMessageTrigger(
74+
channels="test",
75+
redis_conn_id="redis_default",
76+
poll_interval=0.01,
77+
)
78+
79+
mock_redis_conn().pubsub().get_message.return_value = {
80+
"type": "subscribe",
81+
"channel": "test",
82+
"data": "d1",
83+
}
84+
85+
trigger_gen = trigger.run()
86+
task = asyncio.create_task(trigger_gen.__anext__())
87+
await asyncio.sleep(1.0)
88+
assert task.done() is False
89+
task.cancel()
90+
asyncio.get_event_loop().stop()

0 commit comments

Comments
 (0)