Skip to content

Commit ad6423d

Browse files
authored
Optimize job submissions loading (#3466)
* Optimize process_running_jobs select
* Optimize process_runs select
* Add test_calculates_retry_duration_since_last_successful_submission
* Fix _should_retry_job
1 parent 4432cdf commit ad6423d

File tree

3 files changed

+171
-54
lines changed

3 files changed

+171
-54
lines changed

src/dstack/_internal/server/background/tasks/process_running_jobs.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from datetime import timedelta
66
from typing import Dict, List, Optional
77

8-
from sqlalchemy import select
8+
from sqlalchemy import and_, func, select
99
from sqlalchemy.ext.asyncio import AsyncSession
10-
from sqlalchemy.orm import joinedload, load_only
10+
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only
1111

1212
from dstack._internal import settings
1313
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
@@ -139,25 +139,8 @@ async def _process_next_running_job():
139139

140140

141141
async def _process_running_job(session: AsyncSession, job_model: JobModel):
142-
# Refetch to load related attributes.
143-
res = await session.execute(
144-
select(JobModel)
145-
.where(JobModel.id == job_model.id)
146-
.options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
147-
.options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
148-
.execution_options(populate_existing=True)
149-
)
150-
job_model = res.unique().scalar_one()
151-
res = await session.execute(
152-
select(RunModel)
153-
.where(RunModel.id == job_model.run_id)
154-
.options(joinedload(RunModel.project))
155-
.options(joinedload(RunModel.user))
156-
.options(joinedload(RunModel.repo))
157-
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
158-
.options(joinedload(RunModel.jobs))
159-
)
160-
run_model = res.unique().scalar_one()
142+
job_model = await _refetch_job_model(session, job_model)
143+
run_model = await _fetch_run_model(session, job_model.run_id)
161144
repo_model = run_model.repo
162145
project = run_model.project
163146
run = run_model_to_run(run_model, include_sensitive=True)
@@ -421,6 +404,53 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
421404
await session.commit()
422405

423406

407+
async def _refetch_job_model(session: AsyncSession, job_model: JobModel) -> JobModel:
408+
res = await session.execute(
409+
select(JobModel)
410+
.where(JobModel.id == job_model.id)
411+
.options(joinedload(JobModel.instance).joinedload(InstanceModel.project))
412+
.options(joinedload(JobModel.probes).load_only(ProbeModel.success_streak))
413+
.execution_options(populate_existing=True)
414+
)
415+
return res.unique().scalar_one()
416+
417+
418+
async def _fetch_run_model(session: AsyncSession, run_id: uuid.UUID) -> RunModel:
419+
# Select only latest submissions for every job.
420+
latest_submissions_sq = (
421+
select(
422+
JobModel.run_id.label("run_id"),
423+
JobModel.replica_num.label("replica_num"),
424+
JobModel.job_num.label("job_num"),
425+
func.max(JobModel.submission_num).label("max_submission_num"),
426+
)
427+
.where(JobModel.run_id == run_id)
428+
.group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
429+
.subquery()
430+
)
431+
job_alias = aliased(JobModel)
432+
res = await session.execute(
433+
select(RunModel)
434+
.where(RunModel.id == run_id)
435+
.join(job_alias, job_alias.run_id == RunModel.id)
436+
.join(
437+
latest_submissions_sq,
438+
onclause=and_(
439+
job_alias.run_id == latest_submissions_sq.c.run_id,
440+
job_alias.replica_num == latest_submissions_sq.c.replica_num,
441+
job_alias.job_num == latest_submissions_sq.c.job_num,
442+
job_alias.submission_num == latest_submissions_sq.c.max_submission_num,
443+
),
444+
)
445+
.options(joinedload(RunModel.project))
446+
.options(joinedload(RunModel.user))
447+
.options(joinedload(RunModel.repo))
448+
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
449+
.options(contains_eager(RunModel.jobs, alias=job_alias))
450+
)
451+
return res.unique().scalar_one()
452+
453+
424454
async def _wait_for_instance_provisioning_data(session: AsyncSession, job_model: JobModel):
425455
"""
426456
This function will be called until instance IP address appears

src/dstack/_internal/server/background/tasks/process_runs.py

Lines changed: 78 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import datetime
33
from typing import List, Optional, Set, Tuple
44

5-
from sqlalchemy import and_, or_, select
5+
from sqlalchemy import and_, func, or_, select
66
from sqlalchemy.ext.asyncio import AsyncSession
7-
from sqlalchemy.orm import joinedload, load_only, selectinload
7+
from sqlalchemy.orm import aliased, contains_eager, joinedload, load_only
88

99
import dstack._internal.server.services.services.autoscalers as autoscalers
1010
from dstack._internal.core.errors import ServerError
@@ -33,6 +33,7 @@
3333
get_job_specs_from_run_spec,
3434
group_jobs_by_replica_latest,
3535
is_master_job,
36+
job_model_to_job_submission,
3637
switch_job_status,
3738
)
3839
from dstack._internal.server.services.locking import get_locker
@@ -144,22 +145,7 @@ async def _process_next_run():
144145

145146

146147
async def _process_run(session: AsyncSession, run_model: RunModel):
147-
# Refetch to load related attributes.
148-
res = await session.execute(
149-
select(RunModel)
150-
.where(RunModel.id == run_model.id)
151-
.execution_options(populate_existing=True)
152-
.options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name))
153-
.options(joinedload(RunModel.user).load_only(UserModel.name))
154-
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
155-
.options(
156-
selectinload(RunModel.jobs)
157-
.joinedload(JobModel.instance)
158-
.load_only(InstanceModel.fleet_id)
159-
)
160-
.execution_options(populate_existing=True)
161-
)
162-
run_model = res.unique().scalar_one()
148+
run_model = await _refetch_run_model(session, run_model)
163149
logger.debug("%s: processing run", fmt(run_model))
164150
try:
165151
if run_model.status == RunStatus.PENDING:
@@ -181,6 +167,46 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
181167
await session.commit()
182168

183169

170+
async def _refetch_run_model(session: AsyncSession, run_model: RunModel) -> RunModel:
171+
# Select only latest submissions for every job.
172+
latest_submissions_sq = (
173+
select(
174+
JobModel.run_id.label("run_id"),
175+
JobModel.replica_num.label("replica_num"),
176+
JobModel.job_num.label("job_num"),
177+
func.max(JobModel.submission_num).label("max_submission_num"),
178+
)
179+
.where(JobModel.run_id == run_model.id)
180+
.group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
181+
.subquery()
182+
)
183+
job_alias = aliased(JobModel)
184+
res = await session.execute(
185+
select(RunModel)
186+
.where(RunModel.id == run_model.id)
187+
.outerjoin(latest_submissions_sq, latest_submissions_sq.c.run_id == RunModel.id)
188+
.outerjoin(
189+
job_alias,
190+
onclause=and_(
191+
job_alias.run_id == latest_submissions_sq.c.run_id,
192+
job_alias.replica_num == latest_submissions_sq.c.replica_num,
193+
job_alias.job_num == latest_submissions_sq.c.job_num,
194+
job_alias.submission_num == latest_submissions_sq.c.max_submission_num,
195+
),
196+
)
197+
.options(joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name))
198+
.options(joinedload(RunModel.user).load_only(UserModel.name))
199+
.options(joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name))
200+
.options(
201+
contains_eager(RunModel.jobs, alias=job_alias)
202+
.joinedload(JobModel.instance)
203+
.load_only(InstanceModel.fleet_id)
204+
)
205+
.execution_options(populate_existing=True)
206+
)
207+
return res.unique().scalar_one()
208+
209+
184210
async def _process_pending_run(session: AsyncSession, run_model: RunModel):
185211
"""Jobs are not created yet"""
186212
run = run_model_to_run(run_model)
@@ -294,7 +320,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
294320
and job_model.termination_reason
295321
not in {JobTerminationReason.DONE_BY_RUNNER, JobTerminationReason.SCALED_DOWN}
296322
):
297-
current_duration = _should_retry_job(run, job, job_model)
323+
current_duration = await _should_retry_job(session, run, job, job_model)
298324
if current_duration is None:
299325
replica_statuses.add(RunStatus.FAILED)
300326
run_termination_reasons.add(RunTerminationReason.JOB_FAILED)
@@ -552,19 +578,44 @@ def _has_out_of_date_replicas(run: RunModel) -> bool:
552578
return False
553579

554580

555-
def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datetime.timedelta]:
581+
async def _should_retry_job(
582+
session: AsyncSession,
583+
run: Run,
584+
job: Job,
585+
job_model: JobModel,
586+
) -> Optional[datetime.timedelta]:
556587
"""
557588
Checks if the job should be retried.
558589
Returns the current duration of retrying if retry is enabled.
590+
Retrying duration is calculated as the time since `last_processed_at`
591+
of the latest provisioned submission.
559592
"""
560593
if job.job_spec.retry is None:
561594
return None
562595

563596
last_provisioned_submission = None
564-
for job_submission in reversed(job.job_submissions):
565-
if job_submission.job_provisioning_data is not None:
566-
last_provisioned_submission = job_submission
567-
break
597+
if len(job.job_submissions) > 0:
598+
last_submission = job.job_submissions[-1]
599+
if last_submission.job_provisioning_data is not None:
600+
last_provisioned_submission = last_submission
601+
else:
602+
# The caller passes at most one latest submission in job.job_submissions, so check the db.
603+
res = await session.execute(
604+
select(JobModel)
605+
.where(
606+
JobModel.run_id == job_model.run_id,
607+
JobModel.replica_num == job_model.replica_num,
608+
JobModel.job_num == job_model.job_num,
609+
JobModel.job_provisioning_data.is_not(None),
610+
)
611+
.order_by(JobModel.last_processed_at.desc())
612+
.limit(1)
613+
)
614+
last_provisioned_submission_model = res.scalar()
615+
if last_provisioned_submission_model is not None:
616+
last_provisioned_submission = job_model_to_job_submission(
617+
last_provisioned_submission_model
618+
)
568619

569620
if (
570621
job_model.termination_reason is not None
@@ -574,13 +625,10 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet
574625
):
575626
return common.get_current_datetime() - run.submitted_at
576627

577-
if last_provisioned_submission is None:
578-
return None
579-
580628
if (
581-
last_provisioned_submission.termination_reason is not None
582-
and JobTerminationReason(last_provisioned_submission.termination_reason).to_retry_event()
583-
in job.job_spec.retry.on_events
629+
job_model.termination_reason is not None
630+
and job_model.termination_reason.to_retry_event() in job.job_spec.retry.on_events
631+
and last_provisioned_submission is not None
584632
):
585633
return common.get_current_datetime() - last_provisioned_submission.last_processed_at
586634

src/tests/_internal/server/background/tasks/test_process_runs.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import datetime
22
from collections.abc import Iterable
3-
from typing import Union, cast
3+
from typing import Optional, Union, cast
44
from unittest.mock import patch
55

66
import pytest
@@ -15,7 +15,7 @@
1515
TaskConfiguration,
1616
)
1717
from dstack._internal.core.models.instances import InstanceStatus
18-
from dstack._internal.core.models.profiles import Profile, ProfileRetry, Schedule
18+
from dstack._internal.core.models.profiles import Profile, ProfileRetry, RetryEvent, Schedule
1919
from dstack._internal.core.models.resources import Range
2020
from dstack._internal.core.models.runs import (
2121
JobSpec,
@@ -48,6 +48,7 @@ async def make_run(
4848
deployment_num: int = 0,
4949
image: str = "ubuntu:latest",
5050
probes: Iterable[ProbeConfig] = (),
51+
retry: Optional[ProfileRetry] = None,
5152
) -> RunModel:
5253
project = await create_project(session=session)
5354
user = await create_user(session=session)
@@ -58,7 +59,7 @@ async def make_run(
5859
run_name = "test-run"
5960
profile = Profile(
6061
name="test-profile",
61-
retry=True,
62+
retry=retry or True,
6263
)
6364
run_spec = get_run_spec(
6465
repo_id=repo.name,
@@ -230,6 +231,44 @@ async def test_retry_running_to_failed(self, test_db, session: AsyncSession):
230231
assert run.status == RunStatus.TERMINATING
231232
assert run.termination_reason == RunTerminationReason.JOB_FAILED
232233

234+
@pytest.mark.asyncio
235+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
236+
async def test_calculates_retry_duration_since_last_successful_submission(
237+
self, test_db, session: AsyncSession
238+
):
239+
run = await make_run(
240+
session,
241+
status=RunStatus.RUNNING,
242+
replicas=1,
243+
retry=ProfileRetry(duration=300, on_events=[RetryEvent.NO_CAPACITY]),
244+
)
245+
now = run.submitted_at + datetime.timedelta(minutes=10)
246+
# Retry logic should look at this job and calculate retry duration since its last_processed_at.
247+
await create_job(
248+
session=session,
249+
run=run,
250+
status=JobStatus.FAILED,
251+
termination_reason=JobTerminationReason.EXECUTOR_ERROR,
252+
last_processed_at=now - datetime.timedelta(minutes=4),
253+
replica_num=0,
254+
job_provisioning_data=get_job_provisioning_data(),
255+
)
256+
await create_job(
257+
session=session,
258+
run=run,
259+
status=JobStatus.FAILED,
260+
termination_reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY,
261+
replica_num=0,
262+
submission_num=1,
263+
last_processed_at=now - datetime.timedelta(minutes=2),
264+
job_provisioning_data=None,
265+
)
266+
with patch("dstack._internal.utils.common.get_current_datetime") as datetime_mock:
267+
datetime_mock.return_value = now
268+
await process_runs.process_runs()
269+
await session.refresh(run)
270+
assert run.status == RunStatus.PENDING
271+
233272
@pytest.mark.asyncio
234273
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
235274
async def test_pending_to_submitted(self, test_db, session: AsyncSession):

0 commit comments

Comments (0)