88import typing as t
99from contextlib import asynccontextmanager , suppress
1010from datetime import datetime , timezone
11+ from queue import Empty , Queue
1112
1213import aiofiles
1314import dagster as dg
1617
1718logger = logging .getLogger (__name__ )
1819
20+ BeatLoopFunc = t .Callable [..., t .Coroutine [None , None , None ]]
21+
22+
def run_beat_loop(
    interval_seconds: int,
    queue: Queue[bool],
    beat_loop_func: BeatLoopFunc,
    beat_loop_kwargs: dict[str, t.Any],
) -> None:
    """Send heartbeats repeatedly until a stop signal arrives on ``queue``.

    This function blocks and never awaits, so it must not be called directly
    on a running event loop; ``HeartBeatResource.heartbeat`` submits it to a
    thread pool executor. Each beat is driven by ``asyncio.run``, which
    creates and tears down its own event loop for that single call.

    Args:
        interval_seconds: Maximum time to wait for a stop signal between two
            beats; effectively the heartbeat period.
        queue: Control channel. Putting a truthy value on it stops the loop
            after the current wait; a falsy value triggers an immediate beat.
        beat_loop_func: Coroutine function that performs one heartbeat.
        beat_loop_kwargs: Keyword arguments passed to ``beat_loop_func`` on
            every beat.
    """
    logger.info("Starting heartbeat beat loop")
    while True:
        logger.info("Running heartbeat beat loop function")
        try:
            asyncio.run(beat_loop_func(**beat_loop_kwargs))
        except Exception:
            # A single failed beat (e.g. a transient connection error) must
            # not kill the loop: nothing would restart it, and the job would
            # then look dead even though it is still running.
            logger.exception("Error running heartbeat beat loop function")
        try:
            # Doubles as the inter-beat sleep: returns early only when a
            # control value is put on the queue.
            stop_requested = queue.get(timeout=interval_seconds)
        except Empty:
            continue
        if stop_requested:
            logger.info("Stopping heartbeat beat loop")
            break
39+
1940
2041class HeartBeatResource (dg .ConfigurableResource ):
def beat_loop_func(self) -> BeatLoopFunc:
    """Return the coroutine function that sends a single heartbeat.

    The base implementation has nothing to return; subclasses are expected
    to override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
44+
def beat_loop_kwargs(self) -> dict[str, t.Any]:
    """Return the keyword arguments for the beat function.

    The base implementation supplies none and returns a new empty dict.
    """
    no_extra_kwargs: dict[str, t.Any] = {}
    return no_extra_kwargs
47+
2148 async def get_last_heartbeat_for (self , job_name : str ) -> datetime | None :
2249 raise NotImplementedError ()
2350
@@ -30,51 +57,58 @@ async def heartbeat(
3057 job_name : str ,
3158 interval_seconds : int = 120 ,
3259 log_override : logging .Logger | None = None ,
33- ):
34- """Asynchronously run a heartbeat that updates every `interval_seconds`. We
35- use a separate process which should only live as long as the entire pod
36- for the dagster job is alive.
37- """
38-
60+ ) -> t .AsyncIterator [None ]:
3961 log_override = log_override or logger
40-
41- async def _beat_loop ():
42- log_override .info (
43- f"Starting heartbeat for job { job_name } every { interval_seconds } seconds"
62+ loop = asyncio .get_running_loop ()
63+ with concurrent .futures .ThreadPoolExecutor (max_workers = 1 ) as executor :
64+ kwargs = self .beat_loop_kwargs ().copy ()
65+ kwargs .update ({"job_name" : job_name })
66+ queue = Queue [bool ]()
67+ beat_task = loop .run_in_executor (
68+ executor ,
69+ run_beat_loop ,
70+ interval_seconds ,
71+ queue ,
72+ self .beat_loop_func (),
73+ kwargs ,
4474 )
45- while True :
46- log_override .info (f"Beating heartbeat for job { job_name } " )
47- await self .beat (job_name )
48- await asyncio .sleep (interval_seconds )
75+ try :
76+ yield
77+ finally :
78+ queue .put (True )
79+ beat_task .cancel ()
80+ with suppress (asyncio .CancelledError ):
81+ await beat_task
4982
50- async def _beat_process ():
51- loop = asyncio .get_running_loop ()
52- with concurrent .futures .ProcessPoolExecutor (max_workers = 1 ) as pool :
53- await loop .run_in_executor (pool , asyncio .run , _beat_loop ())
5483
55- task = asyncio .create_task (_beat_process ())
56- try :
57- yield
58- finally :
59- task .cancel ()
60- with suppress (asyncio .CancelledError ):
61- await task
@asynccontextmanager
async def async_redis_client(host: str, port: int) -> t.AsyncIterator[Redis]:
    """Yield a ``Redis`` client for ``host``/``port`` and close it on exit.

    The client's ``aclose()`` is awaited when the ``async with`` block ends,
    whether it finishes normally or raises.
    """
    connection = Redis(host=host, port=port)
    try:
        yield connection
    finally:
        await connection.aclose()
91+
92+
async def redis_send_heartbeat(*, host: str, port: int, job_name: str) -> None:
    """Write the current UTC time (ISO 8601) to the job's heartbeat key.

    Defined at module level and driven only by plain keyword arguments, so it
    can be handed to ``run_beat_loop`` without a resource instance. A new
    client is opened and closed on every call.
    """
    async with async_redis_client(host, port) as redis_client:
        heartbeat_key = f"heartbeat:{ job_name } "
        beat_time = datetime.now(timezone.utc).isoformat()
        await redis_client.set(heartbeat_key, beat_time)
6298
6399
64100class RedisHeartBeatResource (HeartBeatResource ):
65101 host : str = Field (description = "Redis host for heartbeat storage." )
66102 port : int = Field (default = 6379 , description = "Redis port for heartbeat storage." )
67103
68- @asynccontextmanager
69- async def redis_client (self ) -> t .AsyncIterator [Redis ]:
70- client = Redis (host = self .host , port = self .port )
71- try :
72- yield client
73- finally :
74- await client .aclose ()
def beat_loop_func(self) -> BeatLoopFunc:
    """Return the module-level coroutine function that writes a beat to Redis."""
    send_beat = redis_send_heartbeat
    return send_beat
106+
def beat_loop_kwargs(self) -> dict[str, t.Any]:
    """Return the Redis connection settings as keyword arguments.

    The keys match the ``host`` and ``port`` keyword-only parameters of
    ``redis_send_heartbeat``.
    """
    connection_kwargs: dict[str, t.Any] = {}
    connection_kwargs["host"] = self.host
    connection_kwargs["port"] = self.port
    return connection_kwargs
75109
76110 async def get_last_heartbeat_for (self , job_name : str ) -> datetime | None :
77- async with self .redis_client ( ) as redis_client :
111+ async with async_redis_client ( self .host , self . port ) as redis_client :
78112 timestamp = await redis_client .get (f"heartbeat:{ job_name } " )
79113 logger .info (f"Fetched heartbeat for job { job_name } : { timestamp } " )
80114 if isinstance (timestamp , str ):
@@ -85,11 +119,19 @@ async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
85119 return None
86120
async def beat(self, job_name: str) -> None:
    """Record one heartbeat for ``job_name`` in Redis.

    Delegates to ``redis_send_heartbeat`` using this resource's ``host`` and
    ``port``.
    """
    redis_host, redis_port = self.host, self.port
    return await redis_send_heartbeat(
        host=redis_host, port=redis_port, job_name=job_name
    )
125+
126+
async def filebased_send_heartbeat(*, directory: str, job_name: str) -> None:
    """Overwrite the job's heartbeat file with the current UTC time (ISO 8601).

    The file lives in ``directory`` and is named after ``job_name``. It is
    opened in ``"w"`` mode, so any previous content is replaced.
    """
    import pathlib

    import aiofiles

    heartbeat_file = pathlib.Path(directory).joinpath(f"{ job_name } _heartbeat.txt")
    async with aiofiles.open(heartbeat_file, mode="w") as handle:
        beat_time = datetime.now(timezone.utc).isoformat()
        await handle.write(beat_time)
93135
94136
95137class FilebasedHeartBeatResource (HeartBeatResource ):
@@ -115,3 +157,33 @@ async def beat(self, job_name: str) -> None:
115157 filepath = Path (self .directory ) / f"{ job_name } _heartbeat.txt"
116158 async with aiofiles .open (filepath , mode = "w" ) as f :
117159 await f .write (datetime .now (timezone .utc ).isoformat ())
160+
161+ @asynccontextmanager
162+ async def heartbeat (
163+ self ,
164+ job_name : str ,
165+ interval_seconds : int = 120 ,
166+ log_override : logging .Logger | None = None ,
167+ ) -> t .AsyncIterator [None ]:
168+ logger_to_use = log_override or logger
169+
170+ async def beat_loop ():
171+ while True :
172+ try :
173+ await self .beat (job_name )
174+ logger_to_use .info (
175+ f"Heartbeat sent for job { job_name } at { datetime .now (timezone .utc ).isoformat ()} "
176+ )
177+ except Exception as e :
178+ logger_to_use .error (
179+ f"Error sending heartbeat for job { job_name } : { e } "
180+ )
181+ await asyncio .sleep (interval_seconds )
182+
183+ beat_task = asyncio .create_task (beat_loop ())
184+ try :
185+ yield
186+ finally :
187+ beat_task .cancel ()
188+ with suppress (asyncio .CancelledError ):
189+ await beat_task
0 commit comments