48 changes: 38 additions & 10 deletions airbyte_cdk/sources/declarative/async_job/job_orchestrator.py
@@ -179,7 +179,7 @@ def __init__(
self._non_breaking_exceptions: List[Exception] = []

def _replace_failed_jobs(self, partition: AsyncPartition) -> None:
failed_status_jobs = (AsyncJobStatus.FAILED, AsyncJobStatus.TIMED_OUT)
failed_status_jobs = (AsyncJobStatus.FAILED,)
jobs_to_replace = [job for job in partition.jobs if job.status() in failed_status_jobs]
for job in jobs_to_replace:
new_job = self._start_job(job.job_parameters(), job.api_job_id())
@@ -359,14 +359,11 @@ def _process_running_partitions_and_yield_completed_ones(
self._process_partitions_with_errors(partition)
case _:
self._stop_timed_out_jobs(partition)
# re-allocate FAILED jobs, but TIMEOUT jobs are not re-allocated
self._reallocate_partition(current_running_partitions, partition)

# job will be restarted in `_start_job`
current_running_partitions.insert(0, partition)

for job in partition.jobs:
# We only remove completed jobs as we want failed/timed out jobs to be re-allocated in priority
if job.status() == AsyncJobStatus.COMPLETED:
self._job_tracker.remove_job(job.api_job_id())
# We only remove completed / timed-out jobs as we want failed jobs to be re-allocated in priority
self._remove_completed_or_timed_out_jobs(partition)

# update the referenced list with running partitions
self._running_partitions = current_running_partitions
@@ -381,8 +378,11 @@ def _stop_partition(self, partition: AsyncPartition) -> None:
def _stop_timed_out_jobs(self, partition: AsyncPartition) -> None:
for job in partition.jobs:
if job.status() == AsyncJobStatus.TIMED_OUT:
# we don't free allocation here because it is expected to retry the job
self._abort_job(job, free_job_allocation=False)
self._abort_job(job, free_job_allocation=True)
raise AirbyteTracedException(
internal_message=f"Job {job.api_job_id()} has timed out. Try increasing the `polling job timeout`.",
failure_type=FailureType.config_error,
)

def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
try:
@@ -392,6 +392,34 @@ def _abort_job(self, job: AsyncJob, free_job_allocation: bool = True) -> None:
except Exception as exception:
LOGGER.warning(f"Could not free budget for job {job.api_job_id()}: {exception}")

def _remove_completed_or_timed_out_jobs(self, partition: AsyncPartition) -> None:
"""
Remove completed or timed out jobs from the partition.

Args:
partition (AsyncPartition): The partition to process.
"""
for job in partition.jobs:
if job.status() in [AsyncJobStatus.COMPLETED, AsyncJobStatus.TIMED_OUT]:
self._job_tracker.remove_job(job.api_job_id())

def _reallocate_partition(
self,
current_running_partitions: List[AsyncPartition],
partition: AsyncPartition,
) -> None:
"""
Re-insert the partition at the front of the running partitions so its FAILED jobs
can be restarted; TIMED_OUT jobs are not re-allocated.

Args:
current_running_partitions (list): The list of currently running partitions.
partition (AsyncPartition): The partition to reallocate.
"""
for job in partition.jobs:
if job.status() != AsyncJobStatus.TIMED_OUT:
# allow the partition's FAILED jobs to be re-allocated
current_running_partitions.insert(0, partition)

def _process_partitions_with_errors(self, partition: AsyncPartition) -> None:
"""
Process a partition with status errors (FAILED and TIMEOUT).
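Taken together, the orchestrator changes above narrow the retry policy: partitions with FAILED jobs are re-inserted at the front of the running list so those jobs can be restarted, while TIMED_OUT jobs are now aborted, have their budget freed, and surface an AirbyteTracedException with FailureType.config_error instead of being retried. Below is a minimal, self-contained sketch of that policy; JobStatus, Job, and JobTimeoutError are simplified stand-ins for illustration, not the CDK's AsyncJob/AsyncJobOrchestrator API.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class JobStatus(Enum):
    RUNNING = auto()
    COMPLETED = auto()
    FAILED = auto()
    TIMED_OUT = auto()


@dataclass
class Job:
    job_id: str
    status: JobStatus


class JobTimeoutError(Exception):
    """Raised for timed-out jobs: retrying is not expected to help, so the error is surfaced to the user."""


def process_partition(jobs: List[Job], retry_queue: List[List[Job]]) -> None:
    """Mirror of the policy above: fail fast on TIMED_OUT jobs, re-queue partitions with FAILED jobs."""
    for job in jobs:
        if job.status is JobStatus.TIMED_OUT:
            # Timed-out jobs are aborted and reported instead of being retried.
            raise JobTimeoutError(
                f"Job {job.job_id} has timed out. Try increasing the polling job timeout."
            )
    if any(job.status is JobStatus.FAILED for job in jobs):
        # Partitions with failed jobs go to the front of the queue so the failed
        # jobs are restarted before new work is scheduled.
        retry_queue.insert(0, jobs)


if __name__ == "__main__":
    queue: List[List[Job]] = []
    process_partition([Job("j1", JobStatus.FAILED), Job("j2", JobStatus.COMPLETED)], queue)
    assert queue, "partition with a FAILED job should be re-queued"
```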
56 changes: 52 additions & 4 deletions airbyte_cdk/sources/declarative/requesters/http_job_repository.py
@@ -273,24 +273,72 @@ def _clean_up_job(self, job_id: str) -> None:
del self._create_job_response_by_id[job_id]
del self._polling_job_response_by_id[job_id]

def _get_creation_response_interpolation_context(self, job: AsyncJob) -> Dict[str, Any]:
"""
Returns the interpolation context for the creation response.

Args:
job (AsyncJob): The job for which to get the creation response interpolation context.

Returns:
Dict[str, Any]: The interpolation context as a dictionary.
"""
# TODO: currently we support only JsonDecoder to decode the response to track the ids or the status
# of the Jobs. We should consider adding support for other decoders, like XMLDecoder, in the future.
creation_response_context = dict(self._create_job_response_by_id[job.api_job_id()].json())
if not "headers" in creation_response_context:
creation_response_context["headers"] = self._create_job_response_by_id[
job.api_job_id()
].headers
if not "request" in creation_response_context:
creation_response_context["request"] = self._create_job_response_by_id[
job.api_job_id()
].request
return creation_response_context

def _get_polling_response_interpolation_context(self, job: AsyncJob) -> Dict[str, Any]:
"""
Returns the interpolation context for the polling response.

Args:
job (AsyncJob): The job for which to get the polling response interpolation context.

Returns:
Dict[str, Any]: The interpolation context as a dictionary.
"""
# TODO: currently we support only JsonDecoder to decode the response to track the ids or the status
# of the Jobs. We should consider adding support for other decoders, like XMLDecoder, in the future.
polling_response_context = dict(self._polling_job_response_by_id[job.api_job_id()].json())
if not "headers" in polling_response_context:
polling_response_context["headers"] = self._polling_job_response_by_id[
job.api_job_id()
].headers
if not "request" in polling_response_context:
polling_response_context["request"] = self._polling_job_response_by_id[
job.api_job_id()
].request
return polling_response_context

def _get_create_job_stream_slice(self, job: AsyncJob) -> StreamSlice:
creation_response = self._create_job_response_by_id[job.api_job_id()].json()
stream_slice = StreamSlice(
partition={},
cursor_slice={},
extra_fields={"creation_response": creation_response},
extra_fields={
"creation_response": self._get_creation_response_interpolation_context(job),
},
)
return stream_slice

def _get_download_targets(self, job: AsyncJob) -> Iterable[str]:
if not self.download_target_requester:
url_response = self._polling_job_response_by_id[job.api_job_id()]
else:
polling_response = self._polling_job_response_by_id[job.api_job_id()].json()
stream_slice: StreamSlice = StreamSlice(
partition={},
cursor_slice={},
extra_fields={"polling_response": polling_response},
extra_fields={
"polling_response": self._get_polling_response_interpolation_context(job),
},
)
url_response = self.download_target_requester.send_request(stream_slice=stream_slice) # type: ignore # we expect download_target_requester to always be presented, otherwise raise an exception as we cannot proceed with the report
if not url_response:
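The two new helpers above build richer interpolation contexts from the stored creation and polling responses: they start from the decoded JSON body and add `headers` and `request` keys only when the body does not already define them. A rough sketch of that merge behavior follows, using a hypothetical FakeResponse stand-in rather than an actual HTTP response object.

```python
from typing import Any, Dict


class FakeResponse:
    """Hypothetical stand-in exposing the .json()/.headers/.request attributes used above."""

    def __init__(self, body: Dict[str, Any], headers: Dict[str, str], request: Any) -> None:
        self._body = body
        self.headers = headers
        self.request = request

    def json(self) -> Dict[str, Any]:
        return self._body


def build_interpolation_context(response: FakeResponse) -> Dict[str, Any]:
    # Start from the decoded JSON body...
    context = dict(response.json())
    # ...and only add "headers" / "request" when the body itself does not define them.
    if "headers" not in context:
        context["headers"] = response.headers
    if "request" not in context:
        context["request"] = response.request
    return context


response = FakeResponse(
    body={"id": "job-123", "status": "created"},
    headers={"Content-Type": "application/json"},
    request="POST /jobs",
)
context = build_interpolation_context(response)
assert context["headers"]["Content-Type"] == "application/json"
```

Checking for existing keys first means an API whose payload already contains `headers` or `request` fields is not overwritten by the response metadata.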
@@ -144,13 +144,10 @@ def test_given_timeout_when_create_and_get_completed_partitions_then_free_budget
)
orchestrator = self._orchestrator([_A_STREAM_SLICE], job_tracker)

with pytest.raises(AirbyteTracedException):
with pytest.raises(AirbyteTracedException) as error:
list(orchestrator.create_and_get_completed_partitions())
assert job_tracker.try_to_get_intent()
assert (
self._job_repository.start.call_args_list
== [call(_A_STREAM_SLICE)] * _MAX_NUMBER_OF_ATTEMPTS
)

assert "Job an api job id has timed out" in str(error.value)

@mock.patch(sleep_mock_target)
def test_given_failure_when_create_and_get_completed_partitions_then_raise_exception(
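The reworked test above asserts on the surfaced error message rather than on the number of retry attempts. As a generic reference (not the CDK's actual test), the pytest pattern for capturing an exception and asserting on its message looks like this:

```python
import pytest


def explode() -> None:
    raise ValueError("job has timed out")


def test_error_message_is_surfaced() -> None:
    # pytest.raises captures the exception so its message can be asserted on directly.
    with pytest.raises(ValueError) as error:
        explode()
    assert "has timed out" in str(error.value)
```

An equivalent, more compact form is `pytest.raises(ValueError, match="has timed out")`.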