diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py
index 79027aeba0..a82d84502f 100644
--- a/src/crawlee/crawlers/_basic/_basic_crawler.py
+++ b/src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -1135,19 +1135,14 @@ async def _handle_request_retries(
                 except Exception as e:
                     raise UserDefinedErrorHandlerError('Exception thrown in user-defined request error handler') from e
                 else:
-                    if new_request is not None:
-                        request = new_request
+                    if new_request is not None and new_request != request:
+                        await request_manager.add_request(new_request)
+                        await self._mark_request_as_handled(request)
+                        return

             await request_manager.reclaim_request(request)
         else:
-            await wait_for(
-                lambda: request_manager.mark_request_as_handled(context.request),
-                timeout=self._internal_timeout,
-                timeout_message='Marking request as handled timed out after '
-                f'{self._internal_timeout.total_seconds()} seconds',
-                logger=self._logger,
-                max_retries=3,
-            )
+            await self._mark_request_as_handled(request)
             await self._handle_failed_request(context, error)
             self._statistics.record_request_processing_failure(request.unique_key)

@@ -1196,16 +1191,7 @@ async def _handle_skipped_request(
         self, request: Request | str, reason: SkippedReason, *, need_mark: bool = False
     ) -> None:
         if need_mark and isinstance(request, Request):
-            request_manager = await self.get_request_manager()
-
-            await wait_for(
-                lambda: request_manager.mark_request_as_handled(request),
-                timeout=self._internal_timeout,
-                timeout_message='Marking request as handled timed out after '
-                f'{self._internal_timeout.total_seconds()} seconds',
-                logger=self._logger,
-                max_retries=3,
-            )
+            await self._mark_request_as_handled(request)
             request.state = RequestState.SKIPPED

         url = request.url if isinstance(request, Request) else request
@@ -1417,14 +1403,8 @@ async def __run_task_function(self) -> None:
                 raise RequestHandlerError(e, context) from e

             await self._commit_request_handler_result(context)
-            await wait_for(
-                lambda: request_manager.mark_request_as_handled(context.request),
-                timeout=self._internal_timeout,
-                timeout_message='Marking request as handled timed out after '
-                f'{self._internal_timeout.total_seconds()} seconds',
-                logger=self._logger,
-                max_retries=3,
-            )
+
+            await self._mark_request_as_handled(request)

             request.state = RequestState.DONE

@@ -1467,14 +1447,7 @@ async def __run_task_function(self) -> None:
                 await request_manager.reclaim_request(request)
                 await self._statistics.error_tracker_retry.add(error=session_error, context=context)
             else:
-                await wait_for(
-                    lambda: request_manager.mark_request_as_handled(context.request),
-                    timeout=self._internal_timeout,
-                    timeout_message='Marking request as handled timed out after '
-                    f'{self._internal_timeout.total_seconds()} seconds',
-                    logger=self._logger,
-                    max_retries=3,
-                )
+                await self._mark_request_as_handled(request)

                 await self._handle_failed_request(context, session_error)
                 self._statistics.record_request_processing_failure(request.unique_key)
@@ -1482,14 +1455,7 @@ async def __run_task_function(self) -> None:
         except ContextPipelineInterruptedError as interrupted_error:
             self._logger.debug('The context pipeline was interrupted', exc_info=interrupted_error)

-            await wait_for(
-                lambda: request_manager.mark_request_as_handled(context.request),
-                timeout=self._internal_timeout,
-                timeout_message='Marking request as handled timed out after '
-                f'{self._internal_timeout.total_seconds()} seconds',
-                logger=self._logger,
-                max_retries=3,
-            )
+            await self._mark_request_as_handled(request)

         except ContextPipelineInitializationError as initialization_error:
             self._logger.debug(
@@ -1660,3 +1626,14 @@ async def _crawler_state_task(self) -> None:
             )

             self._previous_crawler_state = current_state
+
+    async def _mark_request_as_handled(self, request: Request) -> None:
+        request_manager = await self.get_request_manager()
+        await wait_for(
+            lambda: request_manager.mark_request_as_handled(request),
+            timeout=self._internal_timeout,
+            timeout_message='Marking request as handled timed out after '
+            f'{self._internal_timeout.total_seconds()} seconds',
+            logger=self._logger,
+            max_retries=3,
+        )
diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py
index 070267a20a..a3c43acf40 100644
--- a/tests/unit/crawlers/_basic/test_basic_crawler.py
+++ b/tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -18,7 +18,7 @@
 import pytest

 from crawlee import ConcurrencySettings, Glob, service_locator
-from crawlee._request import Request
+from crawlee._request import Request, RequestState
 from crawlee._types import BasicCrawlingContext, EnqueueLinksKwargs, HttpMethod
 from crawlee._utils.robots import RobotsTxtFile
 from crawlee.configuration import Configuration
@@ -1768,3 +1768,39 @@ async def handler(_: BasicCrawlingContext) -> None:

     # Wait for crawler to finish
     await crawler_task
+
+
+async def test_new_request_error_handler() -> None:
+    """Test that a new request returned from the error handler is enqueued and handled properly."""
+    queue = await RequestQueue.open()
+    crawler = BasicCrawler(
+        request_manager=queue,
+    )
+
+    request = Request.from_url('https://a.placeholder.com')
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        if '|test' in context.request.unique_key:
+            return
+        raise ValueError('This error should not be handled by error handler')
+
+    @crawler.error_handler
+    async def error_handler(context: BasicCrawlingContext, error: Exception) -> Request | None:
+        return Request.from_url(
+            context.request.url,
+            unique_key=f'{context.request.unique_key}|test',
+        )
+
+    await crawler.run([request])
+
+    original_request = await queue.get_request(request.unique_key)
+    error_request = await queue.get_request(f'{request.unique_key}|test')
+
+    assert original_request is not None
+    assert original_request.state == RequestState.ERROR_HANDLER
+    assert original_request.was_already_handled
+
+    assert error_request is not None
+    assert error_request.state == RequestState.REQUEST_HANDLER
+    assert error_request.was_already_handled
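
Usage note: with the first hunk above, a user-defined error handler that returns a different Request no longer has it swapped in place and reclaimed under the original request's identity; the replacement is enqueued as a separate request and the failed original is marked as handled. Below is a minimal sketch of the resulting user-facing behavior, assuming crawlee's public BasicCrawler API as exercised by the test above; the URL and the '|retry' unique-key suffix are illustrative, not part of this change.

    from __future__ import annotations

    import asyncio

    from crawlee._request import Request
    from crawlee._types import BasicCrawlingContext
    from crawlee.crawlers import BasicCrawler


    async def main() -> None:
        crawler = BasicCrawler()

        @crawler.router.default_handler
        async def handler(context: BasicCrawlingContext) -> None:
            # Fail on the original request; the replacement carrying the
            # illustrative '|retry' suffix succeeds on its own pass.
            if '|retry' not in context.request.unique_key:
                raise ValueError('transient failure')

        @crawler.error_handler
        async def error_handler(context: BasicCrawlingContext, error: Exception) -> Request | None:
            # Returning a new Request now enqueues it as a separate request
            # and marks the failed original as handled (first hunk above).
            return Request.from_url(
                context.request.url,
                unique_key=f'{context.request.unique_key}|retry',
            )

        await crawler.run([Request.from_url('https://example.com')])


    if __name__ == '__main__':
        asyncio.run(main())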