diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 0f9f3407fc..3be1fe0d69 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -828,6 +828,9 @@ def map_execute_task_cmd( prev_checkpoint, checkpoint_path, ): + logger.info("Registering faulthandler for SIGUSR1 for map tasks") + faulthandler.register(signal.SIGUSR1) + logger.info(get_version_message()) raw_output_data_prefix, checkpoint_path, prev_checkpoint = normalize_inputs( diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 0019e4d79b..1ea64daded 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -459,7 +459,7 @@ def __init__(self, id, spec, closure): self._closure = closure @property - def id(self): + def id(self) -> _identifier.WorkflowExecutionIdentifier: """ :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier """ @@ -532,7 +532,7 @@ def __init__( phase: int, started_at: datetime.datetime, duration: datetime.timedelta, - error: typing.Optional[flytekit.models.core.execution.ExecutionError] = None, + error: typing.Optional[_core_execution.ExecutionError] = None, outputs: typing.Optional[LiteralMapBlob] = None, abort_metadata: typing.Optional[AbortMetadata] = None, created_at: typing.Optional[datetime.datetime] = None, @@ -556,7 +556,7 @@ def __init__( self._updated_at = updated_at @property - def error(self) -> flytekit.models.core.execution.ExecutionError: + def error(self) -> _core_execution.ExecutionError: return self._error @property diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 1af159108a..f41beba33c 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2474,6 +2474,10 @@ def wait( timedelta or a duration in seconds as int. :param sync_nodes: passed along to the sync call for the workflow execution """ + logger.debug( + f"Beginning wait for {execution.id} with {timeout=}, {poll_interval=}, and {sync_nodes=}." + ) + if poll_interval is not None and not isinstance(poll_interval, timedelta): poll_interval = timedelta(seconds=poll_interval) poll_interval = poll_interval or timedelta(seconds=30) @@ -2482,11 +2486,31 @@ def wait( timeout = timedelta(seconds=timeout) time_to_give_up = datetime.max if timeout is None else datetime.now() + timeout + poll_count = 0 while datetime.now() < time_to_give_up: + if poll_count % 10 == 0: + logger.debug(f"Waiting for execution {execution.id} to complete.") + logger.debug(f"Current phase: {execution.closure.phase}, {execution.closure.updated_at=}") + execution = self.sync_execution(execution, sync_nodes=sync_nodes) if execution.is_done: return execution time.sleep(poll_interval.total_seconds()) + poll_count += 1 + + if datetime.now() > time_to_give_up: + logger.info("Wait timeout exceeded. Syncing execution one final time.") + refetched_exec = self.fetch_execution( + project=execution.id.project, + domain=execution.id.domain, + name=execution.id.name) + if refetched_exec.is_done: + logger.info("Re-sync'ed execution found to be complete!") + if sync_nodes: + self.sync_execution(refetched_exec, sync_nodes=True) + return refetched_exec + else: + logger.debug(f"Execution {execution.id} not complete after timeout, phase is {refetched_exec.closure.phase}") raise user_exceptions.FlyteTimeout(f"Execution {self} did not complete before timeout.")