@@ -44,16 +44,21 @@ class AsyncPartition:
     This bucket of api_jobs is a bit useless for this iteration but should become interesting when we will be able to split jobs
     """
 
-    _MAX_NUMBER_OF_ATTEMPTS = 3
+    _DEFAULT_MAX_JOB_RETRY = 3
 
-    def __init__(self, jobs: List[AsyncJob], stream_slice: StreamSlice) -> None:
+    def __init__(
+        self, jobs: List[AsyncJob], stream_slice: StreamSlice, job_max_retry: Optional[int] = None
+    ) -> None:
         self._attempts_per_job = {job: 1 for job in jobs}
         self._stream_slice = stream_slice
+        self._job_max_retry = (
+            job_max_retry if job_max_retry is not None else self._DEFAULT_MAX_JOB_RETRY
+        )
 
     def has_reached_max_attempt(self) -> bool:
         return any(
             map(
-                lambda attempt_count: attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS,
+                lambda attempt_count: attempt_count >= self._job_max_retry,
                 self._attempts_per_job.values(),
             )
         )
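The fallback in the constructor is worth a second look: it tests `job_max_retry is not None` rather than truthiness, so an explicit falsy value like `0` is preserved instead of being silently replaced by the class default. A minimal standalone sketch of that behavior (the helper name is ours, not the CDK's):

```python
from typing import Optional

_DEFAULT_MAX_JOB_RETRY = 3

def resolve_max_retry(job_max_retry: Optional[int]) -> int:
    # Same fallback the constructor performs: only None falls back to the
    # class default; an explicit 0 is respected.
    return job_max_retry if job_max_retry is not None else _DEFAULT_MAX_JOB_RETRY

assert resolve_max_retry(None) == 3  # no override: class default
assert resolve_max_retry(1) == 1     # caller-provided cap
assert resolve_max_retry(0) == 0     # would be lost with a plain truthiness check
```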
@@ -62,7 +67,7 @@ def replace_job(self, job_to_replace: AsyncJob, new_jobs: List[AsyncJob]) -> None:
         current_attempt_count = self._attempts_per_job.pop(job_to_replace, None)
         if current_attempt_count is None:
             raise ValueError("Could not find job to replace")
-        elif current_attempt_count >= self._MAX_NUMBER_OF_ATTEMPTS:
+        elif current_attempt_count >= self._job_max_retry:
             raise ValueError(f"Max attempt reached for job in partition {self._stream_slice}")
 
         new_attempt_count = current_attempt_count + 1
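`replace_job` is the only place the attempt counter advances: the old job's count is popped and its replacements inherit `count + 1`, so the cap bounds retries per logical job rather than per partition. A toy reproduction of that bookkeeping, with strings standing in for `AsyncJob`; note the hunk cuts off after `new_attempt_count` is computed, so the final assignment loop below is inferred, not shown in the diff:

```python
from typing import Dict, List

def replace_job(
    attempts: Dict[str, int], job_to_replace: str, new_jobs: List[str], job_max_retry: int
) -> None:
    # Pop the old job's count, guard the cap, then charge every replacement
    # job with one more attempt.
    current_attempt_count = attempts.pop(job_to_replace, None)
    if current_attempt_count is None:
        raise ValueError("Could not find job to replace")
    elif current_attempt_count >= job_max_retry:
        raise ValueError("Max attempt reached for job in partition")
    for new_job in new_jobs:
        attempts[new_job] = current_attempt_count + 1

attempts = {"job-1": 1}
replace_job(attempts, "job-1", ["job-2"], job_max_retry=3)
assert attempts == {"job-2": 2}
```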
@@ -155,6 +160,7 @@ def __init__(
         message_repository: MessageRepository,
         exceptions_to_break_on: Iterable[Type[Exception]] = tuple(),
         has_bulk_parent: bool = False,
+        job_max_retry: Optional[int] = None,
     ) -> None:
         """
         If the stream slices provided as parameters rely on async job streams that rely on the same JobTracker, `has_bulk_parent`
@@ -175,6 +181,7 @@ def __init__(
         self._message_repository = message_repository
         self._exceptions_to_break_on: Tuple[Type[Exception], ...] = tuple(exceptions_to_break_on)
         self._has_bulk_parent = has_bulk_parent
+        self._job_max_retry = job_max_retry
 
         self._non_breaking_exceptions: List[Exception] = []
@@ -214,7 +221,7 @@ def _start_jobs(self) -> None:
             for _slice in self._slice_iterator:
                 at_least_one_slice_consumed_from_slice_iterator_during_current_iteration = True
                 job = self._start_job(_slice)
-                self._running_partitions.append(AsyncPartition([job], _slice))
+                self._running_partitions.append(AsyncPartition([job], _slice, self._job_max_retry))
                 if self._has_bulk_parent and self._slice_iterator.has_next():
                     break
         except ConcurrentJobLimitReached:
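Taken together, these three hunks thread the new knob end to end: the orchestrator accepts `job_max_retry`, stores it, and forwards it to every `AsyncPartition` it creates in `_start_jobs`. A toy model of that flow, where both classes are stand-ins for the CDK types rather than the real API:

```python
from typing import List, Optional

class ToyPartition:
    _DEFAULT_MAX_JOB_RETRY = 3

    def __init__(self, jobs: List[str], job_max_retry: Optional[int] = None) -> None:
        self._job_max_retry = (
            job_max_retry if job_max_retry is not None else self._DEFAULT_MAX_JOB_RETRY
        )

class ToyOrchestrator:
    def __init__(self, job_max_retry: Optional[int] = None) -> None:
        self._job_max_retry = job_max_retry  # None means "use the partition default"

    def start_partition(self, job: str) -> ToyPartition:
        # Mirrors _start_jobs: the stored cap is forwarded to each new partition.
        return ToyPartition([job], self._job_max_retry)

assert ToyOrchestrator().start_partition("j")._job_max_retry == 3
assert ToyOrchestrator(job_max_retry=1).start_partition("j")._job_max_retry == 1
```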
@@ -359,14 +366,11 @@ def _process_running_partitions_and_yield_completed_ones(
                     self._process_partitions_with_errors(partition)
                 case _:
                     self._stop_timed_out_jobs(partition)
+                    # re-allocate FAILED jobs, but TIMEOUT jobs are not re-allocated
+                    self._reallocate_partition(current_running_partitions, partition)
 
-            # job will be restarted in `_start_job`
-            current_running_partitions.insert(0, partition)
-
-            for job in partition.jobs:
-                # We only remove completed jobs as we want failed/timed out jobs to be re-allocated in priority
-                if job.status() == AsyncJobStatus.COMPLETED:
-                    self._job_tracker.remove_job(job.api_job_id())
+            # We only remove completed / timed out jobs as we want failed jobs to be re-allocated in priority
+            self._remove_completed_jobs(partition)
 
         # update the referenced list with running partitions
         self._running_partitions = current_running_partitions
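The reshuffle above keeps the old priority rule in a more explicit form: partitions carrying failed jobs go to the front of the queue via `insert(0, ...)`, so their retries claim tracker budget before any fresh slice is started. In plain list terms:

```python
# Toy illustration of the requeue order; strings stand in for partitions.
current_running_partitions = ["partition-b", "partition-c"]
current_running_partitions.insert(0, "partition-a-with-failed-jobs")
assert current_running_partitions == [
    "partition-a-with-failed-jobs",  # retried first, ahead of fresh slices
    "partition-b",
    "partition-c",
]
```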
@@ -381,7 +385,6 @@ def _stop_partition(self, partition: AsyncPartition) -> None:
     def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
         for job in partition.jobs:
             if job.status() == AsyncJobStatus.TIMED_OUT:
-                # we don't free allocation here because it is expected to retry the job
                 self._abort_job(job, free_job_allocation=False)
 
     def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
@@ -392,6 +395,31 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
         except Exception as exception:
             LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}")
 
+    def _remove_completed_jobs(self, partition: AsyncPartition) -> None:
+        """
+        Remove the completed jobs of a partition from the job tracker.
+
+        Args:
+            partition (AsyncPartition): The partition to process.
+        """
+        for job in partition.jobs:
+            if job.status() == AsyncJobStatus.COMPLETED:
+                self._job_tracker.remove_job(job.api_job_id())
+
+    def _reallocate_partition(
+        self,
+        current_running_partitions: List[AsyncPartition],
+        partition: AsyncPartition,
+    ) -> None:
+        """
+        Re-insert the partition at the front of the running partitions so
+        that its remaining jobs are restarted in priority.
+        Args:
+            current_running_partitions (list): The list of currently running partitions.
+            partition (AsyncPartition): The partition to reallocate.
+        """
+        current_running_partitions.insert(0, partition)
+
     def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
         """
         Process a partition with status errors (FAILED and TIMEOUT).
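One consequence of `_remove_completed_jobs` dropping only COMPLETED jobs from the tracker: a completed job frees its concurrency slot immediately, while a failed job keeps its slot reserved until its retry is started. A toy tracker showing the budget effect; the real `JobTracker` API is not visible in this diff, so the method names below are assumptions:

```python
class ToyJobTracker:
    def __init__(self, budget: int) -> None:
        self._budget = budget
        self._in_flight = set()

    def try_to_acquire(self, job_id: str) -> bool:
        # Analogue of hitting ConcurrentJobLimitReached when the budget is full.
        if len(self._in_flight) >= self._budget:
            return False
        self._in_flight.add(job_id)
        return True

    def remove_job(self, job_id: str) -> None:
        self._in_flight.discard(job_id)

tracker = ToyJobTracker(budget=1)
assert tracker.try_to_acquire("job-1")
assert not tracker.try_to_acquire("job-2")  # budget exhausted
tracker.remove_job("job-1")                 # what _remove_completed_jobs triggers
assert tracker.try_to_acquire("job-2")      # slot freed for the next job
```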