
Commit dd35256

isolate request and session in request handler
1 parent d44aa89 commit dd35256

File tree

4 files changed: +61 -40 lines

src/crawlee/_types.py

Lines changed: 35 additions & 2 deletions

@@ -2,8 +2,9 @@
 import dataclasses
 from collections.abc import Callable, Iterator, Mapping
+from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypedDict, TypeVar, cast, overload
+from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, Protocol, TypedDict, TypeVar, cast, overload

 from pydantic import ConfigDict, Field, PlainValidator, RootModel

@@ -260,12 +261,27 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None:
 class RequestHandlerRunResult:
     """Record of calls to storage-related context helpers."""

-    def __init__(self, *, key_value_store_getter: GetKeyValueStoreFunction) -> None:
+    _REQUEST_SYNC_FIELDS: ClassVar[frozenset[str]] = frozenset({'headers', 'user_data'})
+    _SESSION_SYNC_FIELDS: ClassVar[frozenset[str]] = frozenset(
+        {'_user_data', '_usage_count', '_error_score', '_cookies'}
+    )
+
+    def __init__(
+        self,
+        *,
+        key_value_store_getter: GetKeyValueStoreFunction,
+        request: Request,
+        session: Session | None = None,
+    ) -> None:
         self._key_value_store_getter = key_value_store_getter
         self.add_requests_calls = list[AddRequestsKwargs]()
         self.push_data_calls = list[PushDataFunctionCall]()
         self.key_value_store_changes = dict[tuple[str | None, str | None, str | None], KeyValueStoreChangeRecords]()

+        # Isolated copies for handler execution
+        self.request = deepcopy(request)
+        self.session = deepcopy(session) if session else None
+
     async def add_requests(
         self,
         requests: Sequence[str | Request],

@@ -315,6 +331,23 @@ async def get_key_value_store(
         return self.key_value_store_changes[id, name, alias]

+    def sync_request(self, sync_request: Request) -> None:
+        """Sync request state from copies back to originals."""
+        for field in self._REQUEST_SYNC_FIELDS:
+            value = getattr(self.request, field)
+            original_value = getattr(sync_request, field)
+            if value != original_value:
+                object.__setattr__(sync_request, field, value)
+
+    def sync_session(self, sync_session: Session | None = None) -> None:
+        """Sync session state from copies back to originals."""
+        if self.session and sync_session:
+            for field in self._SESSION_SYNC_FIELDS:
+                value = getattr(self.session, field)
+                original_value = getattr(sync_session, field)
+                if value != original_value:
+                    object.__setattr__(sync_session, field, value)
+

 @docs_group('Functions')
 class AddRequestsFunction(Protocol):
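
For readers skimming the diff, here is the copy-then-sync pattern in isolation. This is a minimal, self-contained sketch, not crawlee's actual code: the frozen dataclass below is a hypothetical stand-in for crawlee's Request model (the real one is pydantic-based), and only the field names 'headers' and 'user_data' are taken from the diff.

from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any

@dataclass(frozen=True)
class FakeRequest:  # hypothetical stand-in, not crawlee's Request
    url: str
    headers: dict[str, str] = field(default_factory=dict)
    user_data: dict[str, Any] = field(default_factory=dict)

SYNC_FIELDS = frozenset({'headers', 'user_data'})

original = FakeRequest(url='https://example.com')
working_copy = deepcopy(original)

# The handler mutates only the isolated copy.
working_copy.user_data['request_state'] = ['initial', 'handler']

# Write-back of changed fields, mirroring sync_request:
# object.__setattr__ bypasses the frozen-dataclass guard.
for name in SYNC_FIELDS:
    new_value = getattr(working_copy, name)
    if new_value != getattr(original, name):
        object.__setattr__(original, name, new_value)

assert original.user_data == {'request_state': ['initial', 'handler']}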

src/crawlee/crawlers/_adaptive_playwright/_adaptive_playwright_crawler.py

Lines changed: 8 additions & 32 deletions

@@ -290,10 +290,12 @@ async def get_input_state(
             use_state_function = context.use_state

             # New result is created and injected to newly created context. This is done to ensure isolation of sub crawlers.
-            result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
+            result = RequestHandlerRunResult(
+                key_value_store_getter=self.get_key_value_store, request=context.request, session=context.session
+            )
             context_linked_to_result = BasicCrawlingContext(
-                request=deepcopy(context.request),
-                session=deepcopy(context.session),
+                request=result.request,
+                session=result.session,
                 proxy_info=deepcopy(context.proxy_info),
                 send_request=context.send_request,
                 add_requests=result.add_requests,

@@ -314,7 +316,7 @@ async def get_input_state(
                 ),
                 logger=self._logger,
             )
-            return SubCrawlerRun(result=result, run_context=context_linked_to_result)
+            return SubCrawlerRun(result=result)
         except Exception as e:
             return SubCrawlerRun(exception=e)

@@ -370,8 +372,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
             self.track_http_only_request_handler_runs()

             static_run = await self._crawl_one(rendering_type='static', context=context)
-            if static_run.result and static_run.run_context and self.result_checker(static_run.result):
-                self._update_context_from_copy(context, static_run.run_context)
+            if static_run.result and self.result_checker(static_run.result):
                 self._context_result_map[context] = static_run.result
                 return
             if static_run.exception:

@@ -402,7 +403,7 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
         if pw_run.exception is not None:
             raise pw_run.exception

-        if pw_run.result and pw_run.run_context:
+        if pw_run.result:
             if should_detect_rendering_type:
                 detection_result: RenderingType
                 static_run = await self._crawl_one('static', context=context, state=old_state_copy)

@@ -414,7 +415,6 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
                 context.log.debug(f'Detected rendering type {detection_result} for {context.request.url}')
                 self.rendering_type_predictor.store_result(context.request, detection_result)

-            self._update_context_from_copy(context, pw_run.run_context)
             self._context_result_map[context] = pw_run.result

     def pre_navigation_hook(

@@ -451,32 +451,8 @@ def track_browser_request_handler_runs(self) -> None:
     def track_rendering_type_mispredictions(self) -> None:
         self.statistics.state.rendering_type_mispredictions += 1

-    def _update_context_from_copy(self, context: BasicCrawlingContext, context_copy: BasicCrawlingContext) -> None:
-        """Update mutable fields of `context` from `context_copy`.
-
-        Uses object.__setattr__ to bypass frozen dataclass restrictions,
-        allowing state synchronization after isolated crawler execution.
-        """
-        updating_attributes = {
-            'request': ('headers', 'user_data'),
-            'session': ('_user_data', '_usage_count', '_error_score', '_cookies'),
-        }
-
-        for attr, sub_attrs in updating_attributes.items():
-            original_sub_obj = getattr(context, attr)
-            copy_sub_obj = getattr(context_copy, attr)
-
-            # Check that both sub objects are not None
-            if original_sub_obj is None or copy_sub_obj is None:
-                continue
-
-            for sub_attr in sub_attrs:
-                new_value = getattr(copy_sub_obj, sub_attr)
-                object.__setattr__(original_sub_obj, sub_attr, new_value)
-

 @dataclass(frozen=True)
 class SubCrawlerRun:
     result: RequestHandlerRunResult | None = None
     exception: Exception | None = None
-    run_context: BasicCrawlingContext | None = None
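
Why run_context could be dropped: the sub-crawler context is now built directly on result.request and result.session, so the result and the context hold the same copied objects, and returning the result alone is enough. A minimal sketch of that shared-ownership idea, with hypothetical plain dataclasses standing in for RequestHandlerRunResult and BasicCrawlingContext:

from copy import deepcopy
from dataclasses import dataclass
from typing import Any

@dataclass
class FakeResult:  # hypothetical stand-in for RequestHandlerRunResult
    request: dict[str, Any]

@dataclass(frozen=True)
class FakeContext:  # hypothetical stand-in for BasicCrawlingContext
    request: dict[str, Any]

original = {'user_data': {}}
result = FakeResult(request=deepcopy(original))
context = FakeContext(request=result.request)  # shared reference, not a second copy

# A handler mutating state through the context...
context.request['user_data']['state'] = 'handler'

# ...is already visible on the result, so only the result needs to be returned.
assert result.request['user_data'] == {'state': 'handler'}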

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 16 additions & 6 deletions

@@ -1312,7 +1312,12 @@ async def _add_requests(
         return await request_manager.add_requests(context_aware_requests)

-    async def _commit_request_handler_result(self, context: BasicCrawlingContext) -> None:
+    async def _commit_request_handler_result(
+        self,
+        context: BasicCrawlingContext,
+        original_request: Request,
+        original_session: Session | None = None,
+    ) -> None:
         """Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
         result = self._context_result_map[context]

@@ -1324,6 +1329,9 @@ async def _commit_request_handler_result(self, context: BasicCrawlingContext) ->
         await self._commit_key_value_store_changes(result, get_kvs=self.get_key_value_store)

+        result.sync_session(sync_session=original_session)
+        result.sync_request(sync_request=original_request)
+
     @staticmethod
     async def _commit_key_value_store_changes(
         result: RequestHandlerRunResult, get_kvs: GetKeyValueStoreFromRequestHandlerFunction

@@ -1389,11 +1397,13 @@ async def __run_task_function(self) -> None:
             else:
                 session = await self._get_session()
             proxy_info = await self._get_proxy_info(request, session)
-            result = RequestHandlerRunResult(key_value_store_getter=self.get_key_value_store)
+            result = RequestHandlerRunResult(
+                key_value_store_getter=self.get_key_value_store, request=request, session=session
+            )

             context = BasicCrawlingContext(
-                request=request,
-                session=session,
+                request=result.request,
+                session=result.session,
                 proxy_info=proxy_info,
                 send_request=self._prepare_send_request_function(session, proxy_info),
                 add_requests=result.add_requests,

@@ -1416,9 +1426,9 @@ async def __run_task_function(self) -> None:
             except asyncio.TimeoutError as e:
                 raise RequestHandlerError(e, context) from e

-            await self._commit_request_handler_result(context)
+            await self._commit_request_handler_result(context, original_request=request, original_session=session)
             await wait_for(
-                lambda: request_manager.mark_request_as_handled(context.request),
+                lambda: request_manager.mark_request_as_handled(request),
                 timeout=self._internal_timeout,
                 timeout_message='Marking request as handled timed out after '
                 f'{self._internal_timeout.total_seconds()} seconds',
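
Note that sync_request/sync_session run only on the commit path, after the handler has returned without raising; a failed handler therefore leaves the original request and session untouched for the retry. A minimal sketch of that property, with a plain dict standing in for the request:

from copy import deepcopy

original = {'user_data': {}}

def failing_handler(state: dict) -> None:
    state['user_data']['partial'] = True  # mutation lands on the copy only
    raise RuntimeError('handler failed mid-way')

working_copy = deepcopy(original)
try:
    failing_handler(working_copy)
except RuntimeError:
    pass  # no write-back on failure
else:
    original.update(working_copy)  # commit path: sync only after success

# The retry starts from a clean original.
assert original['user_data'] == {}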

tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py

Lines changed: 2 additions & 0 deletions

@@ -802,6 +802,8 @@ async def request_handler(context: AdaptivePlaywrightCrawlingContext) -> None:
     assert session is not None
     assert check_request is not None
+
+    print('Test Session', check_request)
     assert session.user_data.get('session_state') is True
     # Check that request user data was updated in the handler and only once.
     assert check_request.user_data.get('request_state') == ['initial', 'handler']
