22import datetime
33from typing import List , Optional , Set , Tuple
44
5- from sqlalchemy import and_ , or_ , select
5+ from sqlalchemy import and_ , func , or_ , select
66from sqlalchemy .ext .asyncio import AsyncSession
7- from sqlalchemy .orm import joinedload , load_only , selectinload
7+ from sqlalchemy .orm import aliased , contains_eager , joinedload , load_only
88
99import dstack ._internal .server .services .services .autoscalers as autoscalers
1010from dstack ._internal .core .errors import ServerError
3333 get_job_specs_from_run_spec ,
3434 group_jobs_by_replica_latest ,
3535 is_master_job ,
36+ job_model_to_job_submission ,
3637 switch_job_status ,
3738)
3839from dstack ._internal .server .services .locking import get_locker
@@ -144,22 +145,7 @@ async def _process_next_run():
144145
145146
146147async def _process_run (session : AsyncSession , run_model : RunModel ):
147- # Refetch to load related attributes.
148- res = await session .execute (
149- select (RunModel )
150- .where (RunModel .id == run_model .id )
151- .execution_options (populate_existing = True )
152- .options (joinedload (RunModel .project ).load_only (ProjectModel .id , ProjectModel .name ))
153- .options (joinedload (RunModel .user ).load_only (UserModel .name ))
154- .options (joinedload (RunModel .fleet ).load_only (FleetModel .id , FleetModel .name ))
155- .options (
156- selectinload (RunModel .jobs )
157- .joinedload (JobModel .instance )
158- .load_only (InstanceModel .fleet_id )
159- )
160- .execution_options (populate_existing = True )
161- )
162- run_model = res .unique ().scalar_one ()
148+ run_model = await _refetch_run_model (session , run_model )
163149 logger .debug ("%s: processing run" , fmt (run_model ))
164150 try :
165151 if run_model .status == RunStatus .PENDING :
@@ -181,6 +167,46 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
181167 await session .commit ()
182168
183169
async def _refetch_run_model(session: AsyncSession, run_model: RunModel) -> RunModel:
    """
    Reloads the run from the database together with the related attributes
    (project, user, fleet, jobs and their instances).

    `RunModel.jobs` is populated only with the latest submission of every
    (replica_num, job_num) pair; older submissions are not loaded.
    The already loaded run state is overwritten (`populate_existing=True`).
    """
    run_id = run_model.id

    # Highest submission_num for every job of this run.
    latest_sq = (
        select(
            JobModel.run_id.label("run_id"),
            JobModel.replica_num.label("replica_num"),
            JobModel.job_num.label("job_num"),
            func.max(JobModel.submission_num).label("max_submission_num"),
        )
        .where(JobModel.run_id == run_id)
        .group_by(JobModel.run_id, JobModel.replica_num, JobModel.job_num)
        .subquery()
    )

    # Job rows are joined through an alias so that only the rows matching
    # the latest submission numbers end up in the eagerly loaded collection.
    latest_job = aliased(JobModel)
    is_latest_submission = and_(
        latest_job.run_id == latest_sq.c.run_id,
        latest_job.replica_num == latest_sq.c.replica_num,
        latest_job.job_num == latest_sq.c.job_num,
        latest_job.submission_num == latest_sq.c.max_submission_num,
    )

    stmt = (
        select(RunModel)
        .where(RunModel.id == run_id)
        .outerjoin(latest_sq, latest_sq.c.run_id == RunModel.id)
        .outerjoin(latest_job, onclause=is_latest_submission)
        .options(
            joinedload(RunModel.project).load_only(ProjectModel.id, ProjectModel.name),
            joinedload(RunModel.user).load_only(UserModel.name),
            joinedload(RunModel.fleet).load_only(FleetModel.id, FleetModel.name),
            # Fill RunModel.jobs from the explicitly joined alias instead of
            # issuing a separate load of all submissions.
            contains_eager(RunModel.jobs, alias=latest_job)
            .joinedload(JobModel.instance)
            .load_only(InstanceModel.fleet_id),
        )
        .execution_options(populate_existing=True)
    )
    result = await session.execute(stmt)
    # unique() is required because the joined collection yields one row per job.
    return result.unique().scalar_one()
208+
209+
184210async def _process_pending_run (session : AsyncSession , run_model : RunModel ):
185211 """Jobs are not created yet"""
186212 run = run_model_to_run (run_model )
@@ -294,7 +320,7 @@ async def _process_active_run(session: AsyncSession, run_model: RunModel):
294320 and job_model .termination_reason
295321 not in {JobTerminationReason .DONE_BY_RUNNER , JobTerminationReason .SCALED_DOWN }
296322 ):
297- current_duration = _should_retry_job (run , job , job_model )
323+ current_duration = await _should_retry_job (session , run , job , job_model )
298324 if current_duration is None :
299325 replica_statuses .add (RunStatus .FAILED )
300326 run_termination_reasons .add (RunTerminationReason .JOB_FAILED )
@@ -552,19 +578,44 @@ def _has_out_of_date_replicas(run: RunModel) -> bool:
552578 return False
553579
554580
555- def _should_retry_job (run : Run , job : Job , job_model : JobModel ) -> Optional [datetime .timedelta ]:
581+ async def _should_retry_job (
582+ session : AsyncSession ,
583+ run : Run ,
584+ job : Job ,
585+ job_model : JobModel ,
586+ ) -> Optional [datetime .timedelta ]:
556587 """
557588 Checks if the job should be retried.
558589 Returns the current duration of retrying if retry is enabled.
590+ Retrying duration is calculated as the time since `last_processed_at`
591+ of the latest provisioned submission.
559592 """
560593 if job .job_spec .retry is None :
561594 return None
562595
563596 last_provisioned_submission = None
564- for job_submission in reversed (job .job_submissions ):
565- if job_submission .job_provisioning_data is not None :
566- last_provisioned_submission = job_submission
567- break
597+ if len (job .job_submissions ) > 0 :
598+ last_submission = job .job_submissions [- 1 ]
599+ if last_submission .job_provisioning_data is not None :
600+ last_provisioned_submission = last_submission
601+ else :
602+ # The caller passes at most one latest submission in job.job_submissions, so check the db.
603+ res = await session .execute (
604+ select (JobModel )
605+ .where (
606+ JobModel .run_id == job_model .run_id ,
607+ JobModel .replica_num == job_model .replica_num ,
608+ JobModel .job_num == job_model .job_num ,
609+ JobModel .job_provisioning_data .is_not (None ),
610+ )
611+ .order_by (JobModel .last_processed_at .desc ())
612+ .limit (1 )
613+ )
614+ last_provisioned_submission_model = res .scalar ()
615+ if last_provisioned_submission_model is not None :
616+ last_provisioned_submission = job_model_to_job_submission (
617+ last_provisioned_submission_model
618+ )
568619
569620 if (
570621 job_model .termination_reason is not None
@@ -574,13 +625,10 @@ def _should_retry_job(run: Run, job: Job, job_model: JobModel) -> Optional[datet
574625 ):
575626 return common .get_current_datetime () - run .submitted_at
576627
577- if last_provisioned_submission is None :
578- return None
579-
580628 if (
581- last_provisioned_submission .termination_reason is not None
582- and JobTerminationReason ( last_provisioned_submission .termination_reason ) .to_retry_event ()
583- in job . job_spec . retry . on_events
629+ job_model .termination_reason is not None
630+ and job_model .termination_reason .to_retry_event () in job . job_spec . retry . on_events
631+ and last_provisioned_submission is not None
584632 ):
585633 return common .get_current_datetime () - last_provisioned_submission .last_processed_at
586634
0 commit comments