diff --git a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
index bb8fb85f8..343afad0b 100644
--- a/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
+++ b/airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
@@ -179,7 +179,7 @@ def __init__(
         self._non_breaking_exceptions: List[Exception] = []
 
     def _replace_failed_jobs(self, partition: AsyncPartition) -> None:
-        failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
+        failed_status_jobs = (AsyncJobStatus.FAILED,)
         jobs_to_replace = [job for job in partition.jobs if job.status() in failed_status_jobs]
         for job in jobs_to_replace:
             new_job = self._start_job(job.job_parameters(), job.api_job_id())
@@ -359,14 +359,11 @@ def _process_running_partitions_and_yield_completed_ones(
                         self._process_partitions_with_errors(partition)
                     case _:
                         self._stop_timed_out_jobs(partition)
+                        # re-allocate FAILED jobs; TIMED_OUT jobs are not re-allocated
+                        self._reallocate_partition(current_running_partitions, partition)
 
-                        # job will be restarted in `_start_job`
-                        current_running_partitions.insert(0, partition)
-
-                for job in partition.jobs:
-                    # We only remove completed jobs as we want failed/timed out jobs to be re-allocated in priority
-                    if job.status() == AsyncJobStatus.COMPLETED:
-                        self._job_tracker.remove_job(job.api_job_id())
+                # We only remove completed or timed-out jobs, as we want failed jobs to be re-allocated in priority
+                self._remove_completed_or_timed_out_jobs(partition)
 
             # update the referenced list with running partitions
             self._running_partitions = current_running_partitions
@@ -381,8 +378,11 @@ def _stop_partition(self, partition: AsyncPartition) -> None:
     def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
         for job in partition.jobs:
             if job.status() == AsyncJobStatus.TIMED_OUT:
-                # we don't free allocation here because it is expected to retry the job
-                self._abort_job(job, free_job_allocation=False)
+                self._abort_job(job, free_job_allocation=True)
+                raise AirbyteTracedException(
+                    internal_message=f"Job {job.api_job_id()} has timed out. Try increasing the `polling job timeout`.",
+                    failure_type=FailureType.config_error,
+                )
 
     def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
         try:
@@ -392,6 +392,34 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
         except Exception as exception:
             LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}")
 
+    def _remove_completed_or_timed_out_jobs(self, partition: AsyncPartition) -> None:
+        """
+        Remove the partition's completed or timed-out jobs from the job tracker.
+
+        Args:
+            partition (AsyncPartition): The partition to process.
+        """
+        for job in partition.jobs:
+            if job.status() in [AsyncJobStatus.COMPLETED, AsyncJobStatus.TIMED_OUT]:
+                self._job_tracker.remove_job(job.api_job_id())
+
+    def _reallocate_partition(
+        self,
+        current_running_partitions: List[AsyncPartition],
+        partition: AsyncPartition,
+    ) -> None:
+        """
+        Re-insert the partition at the front of the running partitions list, once
+        for each of its jobs that has not timed out, so FAILED jobs are retried first.
+
+        Args:
+            current_running_partitions (list): The list of currently running partitions.
+            partition (AsyncPartition): The partition to re-allocate.
+        """
+        for job in partition.jobs:
+            if job.status() != AsyncJobStatus.TIMED_OUT:
+                # allow the FAILED jobs to be re-allocated for the partition
+                current_running_partitions.insert(0, partition)
+
     def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
         """
         Process a partition with status errors (FAILED and TIMEOUT).
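
# --- Illustrative sketch, not part of the diff above ---
# A minimal, self-contained model of the re-allocation rule introduced in
# `_reallocate_partition`: a partition is pushed back to the front of the
# running queue once per job that has NOT timed out, while TIMED_OUT jobs now
# surface as a config error instead of being retried. All names are stand-ins.
from enum import Enum, auto
from typing import List


class JobStatus(Enum):
    FAILED = auto()
    TIMED_OUT = auto()


def reallocate(running: List[str], partition: str, statuses: List[JobStatus]) -> None:
    for status in statuses:
        if status != JobStatus.TIMED_OUT:
            # FAILED jobs keep their partition in priority position
            running.insert(0, partition)


queue: List[str] = ["p2"]
reallocate(queue, "p1", [JobStatus.FAILED, JobStatus.TIMED_OUT])
assert queue == ["p1", "p2"]  # inserted once: the TIMED_OUT job is skipped
# --- end sketch ---
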
+ """ + for job in partition.jobs: + if job.status() != AsyncJobStatus.TIMED_OUT: + # allow the FAILED jobs to be re-allocated for partition + current_running_partitions.insert(0, partition) + def _process_partitions_with_errors(self, partition: AsyncPartition) -> None: """ Process a partition with status errors (FAILED and TIMEOUT). diff --git a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py index b06d82f5f..e8bca6cc9 100644 --- a/airbyte_cdk/sources/declarative/requesters/http_job_repository.py +++ b/airbyte_cdk/sources/declarative/requesters/http_job_repository.py @@ -273,12 +273,59 @@ def _clean_up_job(self, job_id: str) -> None: del self._create_job_response_by_id[job_id] del self._polling_job_response_by_id[job_id] + def _get_creation_response_interpolation_context(self, job: AsyncJob) -> Dict[str, Any]: + """ + Returns the interpolation context for the creation response. + + Args: + job (AsyncJob): The job for which to get the creation response interpolation context. + + Returns: + Dict[str, Any]: The interpolation context as a dictionary. + """ + # TODO: currently we support only JsonDecoder to decode the response to track the ids or the status + # of the Jobs. We should consider to add the support of other decoders like XMLDecoder, in the future + creation_response_context = dict(self._create_job_response_by_id[job.api_job_id()].json()) + if not "headers" in creation_response_context: + creation_response_context["headers"] = self._create_job_response_by_id[ + job.api_job_id() + ].headers + if not "request" in creation_response_context: + creation_response_context["request"] = self._create_job_response_by_id[ + job.api_job_id() + ].request + return creation_response_context + + def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str, Any]: + """ + Returns the interpolation context for the polling response. + + Args: + job (AsyncJob): The job for which to get the polling response interpolation context. + + Returns: + Dict[str, Any]: The interpolation context as a dictionary. + """ + # TODO: currently we support only JsonDecoder to decode the response to track the ids or the status + # of the Jobs. 
diff --git a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py
index dc81eacbc..56bbf5349 100644
--- a/unit_tests/sources/declarative/async_job/test_job_orchestrator.py
+++ b/unit_tests/sources/declarative/async_job/test_job_orchestrator.py
@@ -144,13 +144,10 @@ def test_given_timeout_when_create_and_get_completed_partitions_then_free_budget
         )
         orchestrator = self._orchestrator([_A_STREAM_SLICE], job_tracker)
 
-        with pytest.raises(AirbyteTracedException):
+        with pytest.raises(AirbyteTracedException) as error:
             list(orchestrator.create_and_get_completed_partitions())
-        assert job_tracker.try_to_get_intent()
-        assert (
-            self._job_repository.start.call_args_list
-            == [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS
-        )
+
+        assert "Job an api job id has timed out" in str(error.value)
 
     @mock.patch(sleep_mock_target)
     def test_given_failure_when_create_and_get_completed_partitions_then_raise_exception(