
Commit a69caf8

fix: Protect Request from partial mutations on request handler failure (#1585)

### Description

- Protect `Request` from partial mutations on request handler failure, relevant to `AdaptivePlaywrightCrawler`

### Issues

- Closes: #1514

1 parent cad5219 commit a69caf8
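At a high level, the commit makes the request handler operate on a deep copy of the `Request` and copies tracked changes back only after the handler finishes successfully, so a failing handler can no longer leave the original request partially mutated. A minimal standalone sketch of that pattern (simplified stand-in `Request` and helper names, not the actual crawlee classes):

```python
from __future__ import annotations

from collections.abc import Callable
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Request:
    """Simplified stand-in for crawlee's Request."""

    url: str
    user_data: dict[str, Any] = field(default_factory=dict)
    headers: dict[str, str] = field(default_factory=dict)


def run_handler_isolated(request: Request, handler: Callable[[Request], None]) -> None:
    """Run the handler against a deep copy; commit mutations only on success."""
    working_copy = deepcopy(request)  # handler mutations land on the copy
    handler(working_copy)  # may raise; the original request stays intact
    # Handler succeeded: apply the tracked changes back to the original.
    request.user_data = working_copy.user_data
    request.headers = working_copy.headers


original = Request(url='https://example.com', user_data={'state': ['initial']})


def failing_handler(req: Request) -> None:
    req.user_data['state'].append('modified')
    raise ValueError('handler failed after mutating the request')


try:
    run_handler_isolated(original, failing_handler)
except ValueError:
    pass

assert original.user_data['state'] == ['initial']  # no partial mutation leaked
```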

File tree

6 files changed: +96 -44 lines changed

src/crawlee/_types.py

Lines changed: 22 additions & 1 deletion

@@ -2,6 +2,7 @@
 import dataclasses
 from collections.abc import Callable, Iterator, Mapping
+from copy import deepcopy
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypedDict, TypeVar, cast, overload

@@ -260,12 +261,24 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None:
 class RequestHandlerRunResult:
     """Record of calls to storage-related context helpers."""

-    def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
+    def __init__(
+        self,
+        *,
+        key_value_store_getter: GetKeyValueStoreFunction,
+        request: Request,
+    ) -> None:
         self._key_value_store_getter = key_value_store_getter
         self.add_requests_calls = list[AddRequestsKwargs]()
         self.push_data_calls = list[PushDataFunctionCall]()
         self.key_value_store_changes = dict[tuple[str | None, str | None, str | None], KeyValueStoreChangeRecords]()

+        # Isolated copies for handler execution
+        self._request = deepcopy(request)
+
+    @property
+    def request(self) -> Request:
+        return self._request
+
     async def add_requests(
         self,
         requests: Sequence[str | Request],

@@ -315,6 +328,14 @@ async def get_key_value_store(

         return self.key_value_store_changes[id, name, alias]

+    def apply_request_changes(self, target: Request) -> None:
+        """Apply tracked changes from handler copy to original request."""
+        if self.request.user_data != target.user_data:
+            target.user_data = self.request.user_data
+
+        if self.request.headers != target.headers:
+            target.headers = self.request.headers
+

 @docs_group('Functions')
 class AddRequestsFunction(Protocol):
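One detail worth noting: `self._request = deepcopy(request)` uses a deep copy because `user_data` typically holds nested mutable containers (the tests below append to a list inside it), and a shallow copy would still share those containers with the original. A quick self-contained illustration:

```python
from copy import copy, deepcopy

original = {'request_state': ['initial']}

shallow = copy(original)
shallow['request_state'].append('leaked')  # mutates the inner list shared with the original
assert original['request_state'] == ['initial', 'leaked']

original = {'request_state': ['initial']}
deep = deepcopy(original)
deep['request_state'].append('contained')  # touches only the deep copy
assert original['request_state'] == ['initial']
```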

src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py

Lines changed: 10 additions & 33 deletions

@@ -290,11 +290,14 @@ async def get_input_state(
             use_state_function = context.use_state

             # New result is created and injected to newly created context. This is done to ensure isolation of sub crawlers.
-            result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
+            result = RequestHandlerRunResult(
+                key_value_store_getter=self.get_key_value_store,
+                request=context.request,
+            )
             context_linked_to_result = BasicCrawlingContext(
-                request=deepcopy(context.request),
-                session=deepcopy(context.session),
-                proxy_info=deepcopy(context.proxy_info),
+                request=result.request,
+                session=context.session,
+                proxy_info=context.proxy_info,
                 send_request=context.send_request,
                 add_requests=result.add_requests,
                 push_data=result.push_data,

@@ -314,7 +317,7 @@ async def get_input_state(
                 ),
                 logger=self._logger,
             )
-            return SubCrawlerRun(result=result, run_context=context_linked_to_result)
+            return SubCrawlerRun(result=result)
         except Exception as e:
             return SubCrawlerRun(exception=e)

@@ -370,8 +373,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
            self.track_http_only_request_handler_runs()

            static_run = await self._crawl_one(rendering_type='static', context=context)
-            if static_run.result and static_run.run_context and self.result_checker(static_run.result):
-                self._update_context_from_copy(context, static_run.run_context)
+            if static_run.result and self.result_checker(static_run.result):
                self._context_result_map[context] = static_run.result
                return
            if static_run.exception:

@@ -402,7 +404,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
        if pw_run.exception is not None:
            raise pw_run.exception

-        if pw_run.result and pw_run.run_context:
+        if pw_run.result:
            if should_detect_rendering_type:
                detection_result: RenderingType
                static_run = await self._crawl_one('static', context=context, state=old_state_copy)

@@ -414,7 +416,6 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
                context.log.debug(f'Detected rendering type {detection_result} for {context.request.url}')
                self.rendering_type_predictor.store_result(context.request, detection_result)

-            self._update_context_from_copy(context, pw_run.run_context)
            self._context_result_map[context] = pw_run.result

    def pre_navigation_hook(

@@ -451,32 +452,8 @@ def track_browser_request_handler_runs(self) -> None:
    def track_rendering_type_mispredictions(self) -> None:
        self.statistics.state.rendering_type_mispredictions += 1

-    def _update_context_from_copy(self, context: BasicCrawlingContext, context_copy: BasicCrawlingContext) -> None:
-        """Update mutable fields of `context` from `context_copy`.
-
-        Uses object.__setattr__ to bypass frozen dataclass restrictions,
-        allowing state synchronization after isolated crawler execution.
-        """
-        updating_attributes = {
-            'request': ('headers', 'user_data'),
-            'session': ('_user_data', '_usage_count', '_error_score', '_cookies'),
-        }
-
-        for attr, sub_attrs in updating_attributes.items():
-            original_sub_obj = getattr(context, attr)
-            copy_sub_obj = getattr(context_copy, attr)
-
-            # Check that both sub objects are not None
-            if original_sub_obj is None or copy_sub_obj is None:
-                continue
-
-            for sub_attr in sub_attrs:
-                new_value = getattr(copy_sub_obj, sub_attr)
-                object.__setattr__(original_sub_obj, sub_attr, new_value)
-

@dataclass(frozen=True)
class SubCrawlerRun:
    result: RequestHandlerRunResult | None = None
    exception: Exception | None = None
-    run_context: BasicCrawlingContext | None = None

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 16 additions & 10 deletions

@@ -69,6 +69,7 @@
     from crawlee.storages import Dataset, KeyValueStore, RequestQueue

 from ._context_pipeline import ContextPipeline
+from ._context_utils import swaped_context
 from ._logging_utils import (
     get_one_line_error_summary_if_possible,
     reduce_asyncio_timeout_error_to_relevant_traceback_parts,

@@ -1326,6 +1327,8 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) ->

         await self._commit_key_value_store_changes(result, get_kvs=self.get_key_value_store)

+        result.apply_request_changes(target=context.request)
+
     @staticmethod
     async def _commit_key_value_store_changes(
         result: RequestHandlerRunResult, get_kvs: GetKeyValueStoreFromRequestHandlerFunction

@@ -1391,10 +1394,10 @@ async def __run_task_function(self) -> None:
        else:
            session = await self._get_session()
            proxy_info = await self._get_proxy_info(request, session)
-        result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
+        result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store, request=request)

        context = BasicCrawlingContext(
-            request=request,
+            request=result.request,
            session=session,
            proxy_info=proxy_info,
            send_request=self._prepare_send_request_function(session, proxy_info),

@@ -1409,10 +1412,12 @@ async def __run_task_function(self) -> None:
        self._statistics.record_request_processing_start(request.unique_key)

        try:
-            self._check_request_collision(context.request, context.session)
+            request.state = RequestState.REQUEST_HANDLER

            try:
-                await self._run_request_handler(context=context)
+                with swaped_context(context, request):
+                    self._check_request_collision(request, session)
+                    await self._run_request_handler(context=context)
            except asyncio.TimeoutError as e:
                raise RequestHandlerError(e, context) from e

@@ -1422,13 +1427,13 @@ async def __run_task_function(self) -> None:

            await self._mark_request_as_handled(request)

-            if context.session and context.session.is_usable:
-                context.session.mark_good()
+            if session and session.is_usable:
+                session.mark_good()

            self._statistics.record_request_processing_finish(request.unique_key)

        except RequestCollisionError as request_error:
-            context.request.no_retry = True
+            request.no_retry = True
            await self._handle_request_error(context, request_error)

        except RequestHandlerError as primary_error:

@@ -1443,7 +1448,7 @@ async def __run_task_function(self) -> None:
            await self._handle_request_error(primary_error.crawling_context, primary_error.wrapped_exception)

        except SessionError as session_error:
-            if not context.session:
+            if not session:
                raise RuntimeError('SessionError raised in a crawling context without a session') from session_error

            if self._error_handler:

@@ -1453,10 +1458,11 @@ async def __run_task_function(self) -> None:
                exc_only = ''.join(traceback.format_exception_only(session_error)).strip()
                self._logger.warning('Encountered "%s", rotating session and retrying...', exc_only)

-            context.session.retire()
+            if session:
+                session.retire()

            # Increment session rotation count.
-            context.request.session_rotation_count = (context.request.session_rotation_count or 0) + 1
+            request.session_rotation_count = (request.session_rotation_count or 0) + 1

            await request_manager.reclaim_request(request)
            await self._statistics.error_tracker_retry.add(error=session_error, context=context)
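Worth spelling out the sequencing this diff sets up: the handler runs while `context.request` points at the isolated copy, `swaped_context` swaps the original back in when the block exits (even on failure), and only the success path reaches `_commit_request_handler_result`, which applies the copy's changes to the original. A condensed sketch of that ordering with simplified stand-ins (hypothetical names, not the crawlee API):

```python
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy


class Ctx:
    """Mutable stand-in for the crawling context."""

    def __init__(self, request: dict) -> None:
        self.request = request


@contextmanager
def swap_back(ctx: Ctx, original: dict) -> Iterator[None]:
    """Restore the original request on exit, whether the handler succeeded or not."""
    try:
        yield
    finally:
        ctx.request = original


original = {'state': ['initial']}
ctx = Ctx(request=deepcopy(original))  # handler sees the isolated copy

try:
    with swap_back(ctx, original):
        ctx.request['state'].append('handler')
        raise ValueError('simulated handler failure')
except ValueError:
    pass  # failure path: the commit step never runs

assert ctx.request is original              # swap happened despite the failure
assert original == {'state': ['initial']}   # the partial mutation never leaked
```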
src/crawlee/crawlers/_basic/_context_utils.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+    from crawlee._request import Request
+
+    from ._basic_crawling_context import BasicCrawlingContext
+
+
+@contextmanager
+def swaped_context(
+    context: BasicCrawlingContext,
+    request: Request,
+) -> Iterator[None]:
+    """Replace context's isolated copies with originals after handler execution."""
+    try:
+        yield
+    finally:
+        # Restore original context state to avoid side effects between different handlers.
+        object.__setattr__(context, 'request', request)
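`BasicCrawlingContext` is a frozen dataclass, so ordinary attribute assignment raises `FrozenInstanceError`; `object.__setattr__` is the standard way to bypass that check, which is what the helper relies on to swap the original request back in. A minimal illustration with a toy dataclass (not the real context type):

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Context:
    """Toy frozen dataclass standing in for BasicCrawlingContext."""

    request: str


ctx = Context(request='isolated-copy')

try:
    ctx.request = 'original'  # frozen dataclasses reject normal assignment
except dataclasses.FrozenInstanceError:
    pass

object.__setattr__(ctx, 'request', 'original')  # bypasses the frozen check
assert ctx.request == 'original'
```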

tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py

Lines changed: 1 addition & 0 deletions

@@ -802,6 +802,7 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:

     assert session is not None
     assert check_request is not None
+
     assert session.user_data.get('session_state') is True
     # Check that request user data was updated in the handler and only once.
     assert check_request.user_data.get('request_state') == ['initial', 'handler']

tests/unit/crawlers/_basic/test_basic_crawler.py

Lines changed: 23 additions & 0 deletions

@@ -1825,6 +1825,29 @@ async def handler(_: BasicCrawlingContext) -> None:
     await crawler_task


+async def test_protect_request_in_run_handlers() -> None:
+    """Test that the request in the crawling context is protected in run handlers."""
+    request_queue = await RequestQueue.open(name='state-test')
+
+    request = Request.from_url('https://test.url/', user_data={'request_state': ['initial']})
+
+    crawler = BasicCrawler(request_manager=request_queue, max_request_retries=0)
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        if isinstance(context.request.user_data['request_state'], list):
+            context.request.user_data['request_state'].append('modified')
+        raise ValueError('Simulated error after modifying request')
+
+    await crawler.run([request])
+
+    check_request = await request_queue.get_request(request.unique_key)
+    assert check_request is not None
+    assert check_request.user_data['request_state'] == ['initial']
+
+    await request_queue.drop()
+
+
 async def test_new_request_error_handler() -> None:
     """Test that error in new_request_handler is handled properly."""
     queue = await RequestQueue.open()
