Skip to content

Commit 10d957f

Browse files
authored
fix(oso_dagster): add heartbeat to any usage of producer trino (#5475)
* fix(dagster): adds the heartbeat to the trino exporter * fix: make heartbeat naming more generic Update warehouse/oso_dagster/assets/default/trino_automation.py * fix: move all trino heartbeat logic into a single place * fix * fix: further clean up * clean up
1 parent 8e625c8 commit 10d957f

File tree

5 files changed

+185
-146
lines changed

5 files changed

+185
-146
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import dagster as dg
2+
from oso_dagster.config import DagsterConfig
3+
from oso_dagster.factories.common import (
4+
AssetFactoryResponse,
5+
early_resources_asset_factory,
6+
)
7+
from oso_dagster.resources.heartbeat import HeartBeatResource
8+
from oso_dagster.resources.trino import TrinoResource
9+
10+
11+
@early_resources_asset_factory()
12+
def trino_automation_assets() -> AssetFactoryResponse:
13+
# Define a job that checks the heartbeat of "producer_trino". If the
14+
# heartbeat is older than 30 minutes, we scale trino down to zero.
15+
@dg.op
16+
async def trino_heartbeat_checker(
17+
context: dg.OpExecutionContext,
18+
global_config: dg.ResourceParam[DagsterConfig],
19+
trino: dg.ResourceParam[TrinoResource],
20+
heartbeat: dg.ResourceParam[HeartBeatResource],
21+
) -> None:
22+
from datetime import datetime, timedelta, timezone
23+
24+
now = datetime.now(timezone.utc)
25+
26+
last_heartbeat = await heartbeat.get_last_heartbeat_for("producer_trino")
27+
if last_heartbeat is None:
28+
context.log.info(
29+
"No heartbeat found for sqlmesh now ensuring trino shutdown."
30+
)
31+
last_heartbeat = now - timedelta(
32+
minutes=global_config.sqlmesh_trino_ttl_minutes + 1
33+
)
34+
35+
# Only scale down trino if we're in a k8s environment
36+
if not global_config.k8s_enabled:
37+
return
38+
39+
if now - last_heartbeat > timedelta(
40+
minutes=global_config.sqlmesh_trino_ttl_minutes
41+
):
42+
context.log.info(
43+
f"No heartbeat detected for sqlmesh in the last {global_config.sqlmesh_trino_ttl_minutes} minutes. Ensuring that producer trino is scaled down."
44+
)
45+
await trino.ensure_shutdown()
46+
else:
47+
context.log.info("Heartbeat detected for sqlmesh. No action needed.")
48+
49+
# Use the in-process executor for the heartbeat monitor job to avoid
50+
# the overhead of spinning up a new k8s pod in addition to the run launcher
51+
@dg.job(executor_def=dg.in_process_executor)
52+
def trino_heartbeat_monitor_job():
53+
trino_heartbeat_checker()
54+
55+
return AssetFactoryResponse(
56+
assets=[],
57+
jobs=[trino_heartbeat_monitor_job],
58+
schedules=[
59+
dg.ScheduleDefinition(
60+
name="trino_heartbeat_monitor_schedule",
61+
job=trino_heartbeat_monitor_job,
62+
cron_schedule="*/15 * * * *",
63+
)
64+
],
65+
)

warehouse/oso_dagster/assets/sqlmesh/sqlmesh.py

Lines changed: 4 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from oso_dagster.config import DagsterConfig
2525
from oso_dagster.factories import AssetFactoryResponse, cacheable_asset_factory
2626
from oso_dagster.factories.common import CacheableDagsterContext
27-
from oso_dagster.resources.heartbeat import HeartBeatResource
2827
from oso_dagster.resources.trino import TrinoResource
2928
from oso_dagster.utils.asynctools import multiple_async_contexts
3029
from pydantic import BaseModel
@@ -164,7 +163,6 @@ async def sqlmesh_project(
164163
global_config: ResourceParam[DagsterConfig],
165164
sqlmesh: SQLMeshResource,
166165
trino: TrinoResource,
167-
heartbeat: HeartBeatResource,
168166
config: SQLMeshRunConfig,
169167
):
170168
restate_models = config.restate_models[:] if config.restate_models else []
@@ -234,68 +232,21 @@ def run_sqlmesh(
234232

235233
# Trino can either be `local-trino` or `trino`
236234
if "trino" in global_config.sqlmesh_gateway:
237-
# Start a heartbeat to indicate that sqlmesh is running
238-
async with heartbeat.heartbeat(
239-
job_name="sqlmesh", interval_seconds=300, log_override=context.log
235+
async with multiple_async_contexts(
236+
trino=trino.ensure_available(log_override=context.log),
240237
):
241-
async with multiple_async_contexts(
242-
trino=trino.ensure_available(log_override=context.log),
243-
):
244-
for result in run_sqlmesh(context, sqlmesh, config):
245-
yield result
238+
for result in run_sqlmesh(context, sqlmesh, config):
239+
yield result
246240
else:
247241
# If we are not running trino we are using duckdb
248242
for result in run_sqlmesh(context, sqlmesh, config):
249243
yield result
250244

251-
# Define a job that checks the heartbeat of sqlmesh runs. If the
252-
# heartbeat is older than 30 minutes, we scale trino down to zero.
253-
@dg.op
254-
async def sqlmesh_heartbeat_checker(
255-
context: dg.OpExecutionContext,
256-
global_config: ResourceParam[DagsterConfig],
257-
trino: TrinoResource,
258-
heartbeat: HeartBeatResource,
259-
) -> None:
260-
from datetime import datetime, timedelta, timezone
261-
262-
now = datetime.now(timezone.utc)
263-
264-
last_heartbeat = await heartbeat.get_last_heartbeat_for("sqlmesh")
265-
if last_heartbeat is None:
266-
context.log.info(
267-
"No heartbeat found for sqlmesh now ensuring trino shutdown."
268-
)
269-
last_heartbeat = now - timedelta(
270-
minutes=global_config.sqlmesh_trino_ttl_minutes + 1
271-
)
272-
273-
# Only scale down trino if we're in a k8s environment
274-
if not global_config.k8s_enabled:
275-
return
276-
277-
if now - last_heartbeat > timedelta(
278-
minutes=global_config.sqlmesh_trino_ttl_minutes
279-
):
280-
context.log.info(
281-
f"No heartbeat detected for sqlmesh in the last {global_config.sqlmesh_trino_ttl_minutes} minutes. Ensuring that producer trino is scaled down."
282-
)
283-
await trino.ensure_shutdown()
284-
else:
285-
context.log.info("Heartbeat detected for sqlmesh. No action needed.")
286-
287-
# Use the in-process executor for the heartbeat monitor job to avoid
288-
# the overhead of spinning up a new k8s pod in addition to the run launcher
289-
@dg.job(executor_def=dg.in_process_executor)
290-
def sqlmesh_heartbeat_monitor_job():
291-
sqlmesh_heartbeat_checker()
292-
293245
all_assets_selection = AssetSelection.assets(sqlmesh_project)
294246

295247
return AssetFactoryResponse(
296248
assets=[sqlmesh_project],
297249
jobs=[
298-
sqlmesh_heartbeat_monitor_job,
299250
define_asset_job(
300251
name="sqlmesh_all_assets",
301252
selection=all_assets_selection,
@@ -356,13 +307,6 @@ def sqlmesh_heartbeat_monitor_job():
356307
),
357308
),
358309
],
359-
schedules=[
360-
dg.ScheduleDefinition(
361-
name="sqlmesh_heartbeat_monitor_schedule",
362-
job=sqlmesh_heartbeat_monitor_job,
363-
cron_schedule="*/15 * * * *",
364-
)
365-
],
366310
)
367311

368312
return cache_context

warehouse/oso_dagster/definitions/resources.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ def k8s_resource_factory(global_config: DagsterConfig) -> K8sResource | K8sApiRe
188188
@resource_factory("trino")
189189
@time_function(logger)
190190
def trino_resource_factory(
191-
global_config: DagsterConfig, k8s: K8sResource | K8sApiResource
191+
global_config: DagsterConfig,
192+
k8s: K8sResource | K8sApiResource,
193+
heartbeat: HeartBeatResource,
192194
) -> TrinoResource:
193195
if not global_config.k8s_enabled:
194196
return TrinoRemoteResource()
@@ -199,6 +201,7 @@ def trino_resource_factory(
199201
coordinator_deployment_name=global_config.trino_k8s_coordinator_deployment_name,
200202
worker_deployment_name=global_config.trino_k8s_worker_deployment_name,
201203
use_port_forward=global_config.k8s_use_port_forward,
204+
heartbeat=heartbeat,
202205
)
203206

204207

warehouse/oso_dagster/resources/heartbeat.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import asyncio
66
import concurrent.futures
77
import logging
8+
import threading
89
import typing as t
910
from contextlib import asynccontextmanager, suppress
1011
from datetime import datetime, timezone
11-
from queue import Empty, Queue
1212

1313
import aiofiles
1414
import dagster as dg
@@ -22,60 +22,72 @@
2222

2323
def run_beat_loop(
2424
interval_seconds: int,
25-
queue: Queue[bool],
25+
stop: threading.Event,
2626
beat_loop_func: BeatLoopFunc,
2727
beat_loop_kwargs: dict[str, t.Any],
2828
) -> None:
2929
logger.info("Starting heartbeat beat loop")
3030
while True:
3131
logger.info("Running heartbeat beat loop function")
3232
asyncio.run(beat_loop_func(**beat_loop_kwargs))
33-
try:
34-
if queue.get(timeout=interval_seconds):
35-
logger.info("Stopping heartbeat beat loop")
36-
break
37-
except Empty:
38-
continue
33+
if stop.wait(timeout=float(interval_seconds)):
34+
logger.info("Stopping heartbeat beat loop")
35+
break
3936

4037

4138
class HeartBeatResource(dg.ConfigurableResource):
4239
def beat_loop_func(self) -> BeatLoopFunc:
40+
"""Return the function to be called in the heartbeat loop"""
4341
raise NotImplementedError()
4442

4543
def beat_loop_kwargs(self) -> dict[str, t.Any]:
44+
"""Return the kwargs to be passed to the heartbeat loop function"""
4645
return {}
4746

48-
async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
47+
async def get_last_heartbeat_for(self, name: str) -> datetime | None:
4948
raise NotImplementedError()
5049

51-
async def beat(self, job_name: str) -> None:
50+
async def beat(self, name: str) -> None:
5251
raise NotImplementedError()
5352

5453
@asynccontextmanager
5554
async def heartbeat(
5655
self,
57-
job_name: str,
56+
name: str,
5857
interval_seconds: int = 120,
5958
log_override: logging.Logger | None = None,
6059
) -> t.AsyncIterator[None]:
6160
log_override = log_override or logger
6261
loop = asyncio.get_running_loop()
6362
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
6463
kwargs = self.beat_loop_kwargs().copy()
65-
kwargs.update({"job_name": job_name})
66-
queue = Queue[bool]()
64+
kwargs.update({"heartbeat_name": name})
65+
stop = threading.Event()
66+
67+
# The beat loop must run in a separate thread because dagster's
68+
# async event loop tends to block despite being async. Using a
69+
# thread ensures that the heartbeat will run assuming the process is
70+
# alive and not completely blocked.
71+
#
72+
# So in order to make this work we must pass all of the context we
73+
# need for the beat loop function as the function running in a
74+
# separate thread might not have access to the same context.
75+
# Additionally, this means changing this to be separate processes
76+
# can be done in the future if needed. This is also why the
77+
# functions used for the beat loop are not methods of the resource
78+
# class implementations.
6779
beat_task = loop.run_in_executor(
6880
executor,
6981
run_beat_loop,
7082
interval_seconds,
71-
queue,
83+
stop,
7284
self.beat_loop_func(),
7385
kwargs,
7486
)
7587
try:
7688
yield
7789
finally:
78-
queue.put(True)
90+
stop.set()
7991
beat_task.cancel()
8092
with suppress(asyncio.CancelledError):
8193
await beat_task
@@ -90,10 +102,10 @@ async def async_redis_client(host: str, port: int) -> t.AsyncIterator[Redis]:
90102
await client.aclose()
91103

92104

93-
async def redis_send_heartbeat(*, host: str, port: int, job_name: str) -> None:
105+
async def redis_send_heartbeat(*, host: str, port: int, heartbeat_name: str) -> None:
94106
async with async_redis_client(host, port) as redis_client:
95107
await redis_client.set(
96-
f"heartbeat:{job_name}", datetime.now(timezone.utc).isoformat()
108+
f"heartbeat:{heartbeat_name}", datetime.now(timezone.utc).isoformat()
97109
)
98110

99111

@@ -107,29 +119,29 @@ def beat_loop_func(self) -> BeatLoopFunc:
107119
def beat_loop_kwargs(self) -> dict[str, t.Any]:
108120
return {"host": self.host, "port": self.port}
109121

110-
async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
122+
async def get_last_heartbeat_for(self, name: str) -> datetime | None:
111123
async with async_redis_client(self.host, self.port) as redis_client:
112-
timestamp = await redis_client.get(f"heartbeat:{job_name}")
113-
logger.info(f"Fetched heartbeat for job {job_name}: {timestamp}")
124+
timestamp = await redis_client.get(f"heartbeat:{name}")
125+
logger.info(f"Fetched heartbeat `{name}`: {timestamp}")
114126
if isinstance(timestamp, str):
115127
return datetime.fromisoformat(timestamp)
116128
elif isinstance(timestamp, bytes):
117129
return datetime.fromisoformat(timestamp.decode("utf-8"))
118130
else:
119131
return None
120132

121-
async def beat(self, job_name: str) -> None:
133+
async def beat(self, name: str) -> None:
122134
return await redis_send_heartbeat(
123-
host=self.host, port=self.port, job_name=job_name
135+
host=self.host, port=self.port, heartbeat_name=name
124136
)
125137

126138

127-
async def filebased_send_heartbeat(*, directory: str, job_name: str) -> None:
139+
async def filebased_send_heartbeat(*, directory: str, heartbeat_name: str) -> None:
128140
from pathlib import Path
129141

130142
import aiofiles
131143

132-
filepath = Path(directory) / f"{job_name}_heartbeat.txt"
144+
filepath = Path(directory) / f"{heartbeat_name}_heartbeat.txt"
133145
async with aiofiles.open(filepath, mode="w") as f:
134146
await f.write(datetime.now(timezone.utc).isoformat())
135147

@@ -139,29 +151,25 @@ class FilebasedHeartBeatResource(HeartBeatResource):
139151

140152
directory: str = Field(description="Directory to store heartbeat files.")
141153

142-
async def get_last_heartbeat_for(self, job_name: str) -> datetime | None:
154+
async def get_last_heartbeat_for(self, name: str) -> datetime | None:
143155
from pathlib import Path
144156

145-
filepath = Path(self.directory) / f"{job_name}_heartbeat.txt"
157+
filepath = Path(self.directory) / f"{name}_heartbeat.txt"
146158
if not filepath.exists():
147159
return None
148160
async with aiofiles.open(filepath, mode="r") as f:
149161
timestamp = await f.read()
150162
return datetime.fromisoformat(timestamp)
151163

152-
async def beat(self, job_name: str) -> None:
153-
from pathlib import Path
154-
155-
import aiofiles
156-
157-
filepath = Path(self.directory) / f"{job_name}_heartbeat.txt"
158-
async with aiofiles.open(filepath, mode="w") as f:
159-
await f.write(datetime.now(timezone.utc).isoformat())
164+
async def beat(self, name: str) -> None:
165+
return await filebased_send_heartbeat(
166+
directory=self.directory, heartbeat_name=name
167+
)
160168

161169
@asynccontextmanager
162170
async def heartbeat(
163171
self,
164-
job_name: str,
172+
name: str,
165173
interval_seconds: int = 120,
166174
log_override: logging.Logger | None = None,
167175
) -> t.AsyncIterator[None]:
@@ -170,14 +178,12 @@ async def heartbeat(
170178
async def beat_loop():
171179
while True:
172180
try:
173-
await self.beat(job_name)
181+
await self.beat(name)
174182
logger_to_use.info(
175-
f"Heartbeat sent for job {job_name} at {datetime.now(timezone.utc).isoformat()}"
183+
f"Heartbeat sent for job {name} at {datetime.now(timezone.utc).isoformat()}"
176184
)
177185
except Exception as e:
178-
logger_to_use.error(
179-
f"Error sending heartbeat for job {job_name}: {e}"
180-
)
186+
logger_to_use.error(f"Error sending heartbeat for job {name}: {e}")
181187
await asyncio.sleep(interval_seconds)
182188

183189
beat_task = asyncio.create_task(beat_loop())

0 commit comments

Comments
 (0)