
Commit c78ea54

fix: fixes heartbeat errors that were getting hidden (#5471)
1 parent e2e2d29 commit c78ea54
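The underlying problem: the heartbeat loop previously ran as a fire-and-forget asyncio task that dispatched its work to a ProcessPoolExecutor, which made it easy for exceptions raised inside the loop to go unreported (hence the "hidden" errors in the title). As a general illustration of that failure mode (a minimal sketch, not code from this commit):

import asyncio

async def failing_beat():
    raise RuntimeError("heartbeat backend unreachable")

async def job():
    task = asyncio.create_task(failing_beat())  # fire-and-forget: nothing awaits the result
    await asyncio.sleep(0.1)                    # the "real" work keeps going
    print("job finished; the heartbeat failure never surfaced here")
    print("hidden error:", task.exception())    # only visible if someone inspects the task

asyncio.run(job())

The commit makes such failures visible by running the beat loop in a dedicated worker thread (resources/heartbeat.py below) and, for the file-based resource, by logging the exception from every failed beat.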

File tree

3 files changed: +142, -41 lines changed

warehouse/oso_dagster/assets/default/heartbeat.py

Lines changed: 29 additions & 2 deletions
@@ -3,6 +3,7 @@
 """
 
 import dagster as dg
+from oso_dagster.config import DagsterConfig
 from oso_dagster.factories.common import (
     AssetFactoryResponse,
     early_resources_asset_factory,
@@ -15,7 +16,7 @@ class HeartbeatConfig(dg.Config):
 
 
 @early_resources_asset_factory()
-def heartbeat_factory() -> AssetFactoryResponse:
+def heartbeat_factory(global_config: DagsterConfig) -> AssetFactoryResponse:
     @dg.op
     async def heartbeat_fake_beat(
         config: HeartbeatConfig, heartbeat: HeartBeatResource
@@ -44,7 +45,33 @@ async def heartbeat_get_last_heartbeat(
     def heartbeat_get_last_heartbeat_job():
         heartbeat_get_last_heartbeat()
 
+    @dg.asset
+    async def heartbeat_test_asset(
+        context: dg.AssetExecutionContext, heartbeat: HeartBeatResource
+    ) -> dg.MaterializeResult:
+        # Return a basic dataframe with a row of data with 3 columns
+
+        async with heartbeat.heartbeat(
+            "heartbeat_noop_asset", interval_seconds=5, log_override=context.log
+        ):
+            # Intentionally do some work here to simulate a long running asset
+            # that consumes a lot of cpu
+            def fibonacci(n):
+                if n <= 1:
+                    return n
+                else:
+                    return fibonacci(n - 1) + fibonacci(n - 2)
+
+            for i in range(40):
+                context.log.info(f"Fibonacci({i}) = {fibonacci(i)}")
+            context.log.info("Heartbeat asset completed work")
+        return dg.MaterializeResult(metadata={"info": "Heartbeat asset completed"})
+
+    assets = []
+    if global_config.test_assets_enabled:
+        assets.append(heartbeat_test_asset)
+
     return AssetFactoryResponse(
-        assets=[],
+        assets=assets,
         jobs=[heartbeat_fake_beat_job, heartbeat_get_last_heartbeat_job],
     )
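The new test asset deliberately burns CPU with an exponential-time recursive Fibonacci so that the event loop is effectively blocked while the heartbeat is supposed to keep beating. The reason this matters: a coroutine scheduled on the same event loop cannot run while the loop is busy, so a purely asyncio-based heartbeat would stall. A minimal illustration of that effect (illustrative only, not from this commit):

import asyncio
import time

async def ticker():
    # stands in for an event-loop-based heartbeat
    while True:
        print("tick", time.monotonic())
        await asyncio.sleep(1)

async def main():
    task = asyncio.create_task(ticker())
    await asyncio.sleep(0.1)  # let the ticker start
    time.sleep(5)             # blocking CPU-style work: no ticks are printed during this
    await asyncio.sleep(0.1)
    task.cancel()

asyncio.run(main())

Running the beat loop in a separate worker thread (see resources/heartbeat.py below) keeps the beats flowing even during this kind of workload.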

warehouse/oso_dagster/definitions/resources.py

Lines changed: 3 additions & 1 deletion
@@ -415,15 +415,17 @@ def nessie_factory(global_config: DagsterConfig):
 @time_function(logger)
 def heartbeat_factory(global_config: DagsterConfig) -> HeartBeatResource:
     """Factory function to create a heartbeat resource."""
-    if global_config.k8s_enabled:
+    if global_config.k8s_enabled or global_config.redis_host:
         assert global_config.redis_host is not None, (
             "Redis host must be set for Redis heartbeat."
         )
+        logger.info("Using RedisHeartBeatResource for heartbeat.")
         return RedisHeartBeatResource(
             host=global_config.redis_host,
             port=global_config.redis_port,
         )
     else:
+        logger.info("Using FilebasedHeartBeatResource for heartbeat.")
         return FilebasedHeartBeatResource(
             directory=global_config.dagster_home,
         )
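With this change the Redis-backed heartbeat is chosen whenever a Redis host is configured, not only when running under Kubernetes, and the chosen backend is logged. A rough sketch of the resulting selection behaviour (same config fields as above, illustrative only):

def pick_heartbeat_backend(k8s_enabled: bool, redis_host: str | None) -> str:
    # mirrors the branch above: Redis wins if k8s is enabled or a host is set
    if k8s_enabled or redis_host:
        return "RedisHeartBeatResource"
    return "FilebasedHeartBeatResource"

assert pick_heartbeat_backend(False, "redis.internal") == "RedisHeartBeatResource"
assert pick_heartbeat_backend(True, "redis.internal") == "RedisHeartBeatResource"
assert pick_heartbeat_backend(False, None) == "FilebasedHeartBeatResource"

If k8s_enabled is set without a redis_host, the assert in the factory still fails loudly, matching the previous behaviour.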

warehouse/oso_dagster/resources/heartbeat.py

Lines changed: 110 additions & 38 deletions
@@ -8,6 +8,7 @@
 import typing as t
 from contextlib import asynccontextmanager, suppress
 from datetime import datetime, timezone
+from queue import Empty, Queue
 
 import aiofiles
 import dagster as dg
@@ -16,8 +17,34 @@
 
 logger = logging.getLogger(__name__)
 
+BeatLoopFunc = t.Callable[..., t.Coroutine[None, None, None]]
+
+
+def run_beat_loop(
+    interval_seconds: int,
+    queue: Queue[bool],
+    beat_loop_func: BeatLoopFunc,
+    beat_loop_kwargs: dict[str, t.Any],
+) -> None:
+    logger.info("Starting heartbeat beat loop")
+    while True:
+        logger.info("Running heartbeat beat loop function")
+        asyncio.run(beat_loop_func(**beat_loop_kwargs))
+        try:
+            if queue.get(timeout=interval_seconds):
+                logger.info("Stopping heartbeat beat loop")
+                break
+        except Empty:
+            continue
+
 
 class HeartBeatResource(dg.ConfigurableResource):
+    def beat_loop_func(self) -> BeatLoopFunc:
+        raise NotImplementedError()
+
+    def beat_loop_kwargs(self) -> dict[str, t.Any]:
+        return {}
+
     async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
         raise NotImplementedError()
 
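run_beat_loop is ordinary synchronous code meant to run in a worker thread: each iteration sends one beat via asyncio.run, and queue.get(timeout=interval_seconds) doubles as both the sleep between beats and the shutdown signal (a put(True) from the main thread wakes it immediately). A standalone sketch of that Queue-based pattern (illustrative names, not from the commit):

import threading
import time
from queue import Empty, Queue

def worker(stop: "Queue[bool]", interval: float) -> None:
    while True:
        print("beat", time.monotonic())
        try:
            if stop.get(timeout=interval):  # blocks up to `interval`, wakes early on put()
                break                       # shutdown requested
        except Empty:
            continue                        # interval elapsed with no signal: beat again

stop: "Queue[bool]" = Queue()
t = threading.Thread(target=worker, args=(stop, 1.0))
t.start()
time.sleep(3)
stop.put(True)  # stop without waiting out the remaining interval
t.join()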
@@ -30,51 +57,58 @@ async def heartbeat(
         job_name: str,
         interval_seconds: int = 120,
         log_override: logging.Logger | None = None,
-    ):
-        """Asynchronously run a heartbeat that updates every `interval_seconds`. We
-        use a separate process which should only live as long as the entire pod
-        for the dagster job is alive.
-        """
-
+    ) -> t.AsyncIterator[None]:
         log_override = log_override or logger
-
-        async def _beat_loop():
-            log_override.info(
-                f"Starting heartbeat for job {job_name} every {interval_seconds} seconds"
+        loop = asyncio.get_running_loop()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
+            kwargs = self.beat_loop_kwargs().copy()
+            kwargs.update({"job_name": job_name})
+            queue = Queue[bool]()
+            beat_task = loop.run_in_executor(
+                executor,
+                run_beat_loop,
+                interval_seconds,
+                queue,
+                self.beat_loop_func(),
+                kwargs,
             )
-            while True:
-                log_override.info(f"Beating heartbeat for job {job_name}")
-                await self.beat(job_name)
-                await asyncio.sleep(interval_seconds)
+            try:
+                yield
+            finally:
+                queue.put(True)
+                beat_task.cancel()
+                with suppress(asyncio.CancelledError):
+                    await beat_task
 
-        async def _beat_process():
-            loop = asyncio.get_running_loop()
-            with concurrent.futures.ProcessPoolExecutor(max_workers=1) as pool:
-                await loop.run_in_executor(pool, asyncio.run, _beat_loop())
 
-        task = asyncio.create_task(_beat_process())
-        try:
-            yield
-        finally:
-            task.cancel()
-            with suppress(asyncio.CancelledError):
-                await task
+@asynccontextmanager
+async def async_redis_client(host: str, port: int) -> t.AsyncIterator[Redis]:
+    client = Redis(host=host, port=port)
+    try:
+        yield client
+    finally:
+        await client.aclose()
+
+
+async def redis_send_heartbeat(*, host: str, port: int, job_name: str) -> None:
+    async with async_redis_client(host, port) as redis_client:
+        await redis_client.set(
+            f"heartbeat:{job_name}", datetime.now(timezone.utc).isoformat()
+        )
 
 
 class RedisHeartBeatResource(HeartBeatResource):
     host: str = Field(description="Redis host for heartbeat storage.")
     port: int = Field(default=6379, description="Redis port for heartbeat storage.")
 
-    @asynccontextmanager
-    async def redis_client(self) -> t.AsyncIterator[Redis]:
-        client = Redis(host=self.host, port=self.port)
-        try:
-            yield client
-        finally:
-            await client.aclose()
+    def beat_loop_func(self) -> BeatLoopFunc:
+        return redis_send_heartbeat
+
+    def beat_loop_kwargs(self) -> dict[str, t.Any]:
+        return {"host": self.host, "port": self.port}
 
     async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
-        async with self.redis_client() as redis_client:
+        async with async_redis_client(self.host, self.port) as redis_client:
             timestamp = await redis_client.get(f"heartbeat:{job_name}")
             logger.info(f"Fetched heartbeat for job {job_name}: {timestamp}")
             if isinstance(timestamp, str):
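Because the beat coroutine is now executed with asyncio.run inside the worker thread, it is supplied as a self-contained coroutine function plus plain keyword arguments, which appears to be why redis_send_heartbeat moved out of the class to module scope and the resource only provides the beat_loop_func / beat_loop_kwargs hooks. A hypothetical subclass (not part of this commit) would plug into the loop like this:

import typing as t

async def stdout_send_heartbeat(*, prefix: str, job_name: str) -> None:
    # self-contained coroutine; run via asyncio.run() in the worker thread,
    # with job_name injected by HeartBeatResource.heartbeat()
    print(f"{prefix} heartbeat for {job_name}")

class StdoutHeartBeatResource(HeartBeatResource):  # HeartBeatResource as defined above
    prefix: str = "[beat]"

    def beat_loop_func(self) -> BeatLoopFunc:
        return stdout_send_heartbeat

    def beat_loop_kwargs(self) -> dict[str, t.Any]:
        return {"prefix": self.prefix}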
@@ -85,11 +119,19 @@ async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
         return None
 
     async def beat(self, job_name: str) -> None:
-        async with self.redis_client() as redis_client:
-            logger.info(f"Setting heartbeat for job {job_name}")
-            await redis_client.set(
-                f"heartbeat:{job_name}", datetime.now(timezone.utc).isoformat()
-            )
+        return await redis_send_heartbeat(
+            host=self.host, port=self.port, job_name=job_name
+        )
+
+
+async def filebased_send_heartbeat(*, directory: str, job_name: str) -> None:
+    from pathlib import Path
+
+    import aiofiles
+
+    filepath = Path(directory) / f"{job_name}_heartbeat.txt"
+    async with aiofiles.open(filepath, mode="w") as f:
+        await f.write(datetime.now(timezone.utc).isoformat())
 
 
 class FilebasedHeartBeatResource(HeartBeatResource):
@@ -115,3 +157,33 @@ async def beat(self, job_name: str) -> None:
         filepath = Path(self.directory) / f"{job_name}_heartbeat.txt"
         async with aiofiles.open(filepath, mode="w") as f:
             await f.write(datetime.now(timezone.utc).isoformat())
+
+    @asynccontextmanager
+    async def heartbeat(
+        self,
+        job_name: str,
+        interval_seconds: int = 120,
+        log_override: logging.Logger | None = None,
+    ) -> t.AsyncIterator[None]:
+        logger_to_use = log_override or logger
+
+        async def beat_loop():
+            while True:
+                try:
+                    await self.beat(job_name)
+                    logger_to_use.info(
+                        f"Heartbeat sent for job {job_name} at {datetime.now(timezone.utc).isoformat()}"
+                    )
+                except Exception as e:
+                    logger_to_use.error(
+                        f"Error sending heartbeat for job {job_name}: {e}"
+                    )
+                await asyncio.sleep(interval_seconds)
+
+        beat_task = asyncio.create_task(beat_loop())
+        try:
+            yield
+        finally:
+            beat_task.cancel()
+            with suppress(asyncio.CancelledError):
+                await beat_task
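FilebasedHeartBeatResource keeps a pure-asyncio heartbeat: the loop is a plain task on the current event loop, but every beat is wrapped in try/except so a failing write is logged instead of silently killing the loop. A rough local usage sketch, assuming the resource only needs the directory field shown above:

import asyncio
import tempfile

async def demo() -> None:
    resource = FilebasedHeartBeatResource(directory=tempfile.mkdtemp())
    async with resource.heartbeat("demo_job", interval_seconds=1):
        await asyncio.sleep(3)  # simulated work; beats land in <dir>/demo_job_heartbeat.txt

asyncio.run(demo())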
