Skip to content

Commit 6b014a7

Browse files
committed
replace result request and session in context with originals
1 parent dd35256 commit 6b014a7

File tree

3 files changed

+75
-38
lines changed

3 files changed

+75
-38
lines changed

src/crawlee/_types.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Callable, Iterator, Mapping
55
from copy import deepcopy
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, Protocol, TypedDict, TypeVar, cast, overload
7+
from typing import TYPE_CHECKING, Annotated, Any, Literal, Protocol, TypedDict, TypeVar, cast, overload
88

99
from pydantic import ConfigDict, Field, PlainValidator, RootModel
1010

@@ -261,11 +261,6 @@ async def get_value(self, key: str, default_value: T | None = None) -> T | None:
261261
class RequestHandlerRunResult:
262262
"""Record of calls to storage-related context helpers."""
263263

264-
_REQUEST_SYNC_FIELDS: ClassVar[frozenset[str]] = frozenset({'headers', 'user_data'})
265-
_SESSION_SYNC_FIELDS: ClassVar[frozenset[str]] = frozenset(
266-
{'_user_data', '_usage_count', '_error_score', '_cookies'}
267-
)
268-
269264
def __init__(
270265
self,
271266
*,
@@ -279,8 +274,16 @@ def __init__(
279274
self.key_value_store_changes = dict[tuple[str | None, str | None, str | None], KeyValueStoreChangeRecords]()
280275

281276
# Isolated copies for handler execution
282-
self.request = deepcopy(request)
283-
self.session = deepcopy(session) if session else None
277+
self._request = deepcopy(request)
278+
self._session = deepcopy(session) if session else None
279+
280+
@property
281+
def request(self) -> Request:
282+
return self._request
283+
284+
@property
285+
def session(self) -> Session | None:
286+
return self._session
284287

285288
async def add_requests(
286289
self,
@@ -331,22 +334,29 @@ async def get_key_value_store(
331334

332335
return self.key_value_store_changes[id, name, alias]
333336

334-
def sync_request(self, sync_request: Request) -> None:
335-
"""Sync request state from copies back to originals."""
336-
for field in self._REQUEST_SYNC_FIELDS:
337-
value = getattr(self.request, field)
338-
original_value = getattr(sync_request, field)
339-
if value != original_value:
340-
object.__setattr__(sync_request, field, value)
341-
342-
def sync_session(self, sync_session: Session | None = None) -> None:
343-
"""Sync session state from copies back to originals."""
344-
if self.session and sync_session:
345-
for field in self._SESSION_SYNC_FIELDS:
337+
def apply_request_changes(self, target: Request) -> None:
338+
"""Apply tracked changes from handler copy to original request."""
339+
if self.request.user_data != target.user_data:
340+
target.user_data.update(self.request.user_data)
341+
342+
if self.request.headers != target.headers:
343+
target.headers = target.headers | self.request.headers
344+
345+
def apply_session_changes(self, target: Session | None = None) -> None:
346+
"""Apply tracked changes from handler copy to original session."""
347+
simple_fields: set[str] = {'_usage_count', '_error_score'}
348+
349+
if self.session and target:
350+
if self.session.user_data != target.user_data:
351+
target.user_data.update(self.session.user_data)
352+
353+
if self.session.cookies != target.cookies:
354+
target.cookies.set_cookies(self.session.cookies.get_cookies_as_dicts())
355+
for field in simple_fields:
346356
value = getattr(self.session, field)
347-
original_value = getattr(sync_session, field)
357+
original_value = getattr(target, field)
348358
if value != original_value:
349-
object.__setattr__(sync_session, field, value)
359+
object.__setattr__(target, field, value)
350360

351361

352362
@docs_group('Functions')

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from crawlee.storages import Dataset, KeyValueStore, RequestQueue
6666

6767
from ._context_pipeline import ContextPipeline
68+
from ._context_utils import swaped_context
6869
from ._logging_utils import (
6970
get_one_line_error_summary_if_possible,
7071
reduce_asyncio_timeout_error_to_relevant_traceback_parts,
@@ -1315,8 +1316,6 @@ async def _add_requests(
13151316
async def _commit_request_handler_result(
13161317
self,
13171318
context: BasicCrawlingContext,
1318-
original_request: Request,
1319-
original_session: Session | None = None,
13201319
) -> None:
13211320
"""Commit request handler result for the input `context`. Result is taken from `_context_result_map`."""
13221321
result = self._context_result_map[context]
@@ -1329,8 +1328,8 @@ async def _commit_request_handler_result(
13291328

13301329
await self._commit_key_value_store_changes(result, get_kvs=self.get_key_value_store)
13311330

1332-
result.sync_session(sync_session=original_session)
1333-
result.sync_request(sync_request=original_request)
1331+
result.apply_session_changes(target=context.session)
1332+
result.apply_request_changes(target=context.request)
13341333

13351334
@staticmethod
13361335
async def _commit_key_value_store_changes(
@@ -1419,14 +1418,14 @@ async def __run_task_function(self) -> None:
14191418
try:
14201419
request.state = RequestState.REQUEST_HANDLER
14211420

1422-
self._check_request_collision(context.request, context.session)
1423-
14241421
try:
1425-
await self._run_request_handler(context=context)
1422+
with swaped_context(context, request, session):
1423+
self._check_request_collision(request, session)
1424+
await self._run_request_handler(context=context)
14261425
except asyncio.TimeoutError as e:
14271426
raise RequestHandlerError(e, context) from e
14281427

1429-
await self._commit_request_handler_result(context, original_request=request, original_session=session)
1428+
await self._commit_request_handler_result(context)
14301429
await wait_for(
14311430
lambda: request_manager.mark_request_as_handled(request),
14321431
timeout=self._internal_timeout,
@@ -1438,13 +1437,13 @@ async def __run_task_function(self) -> None:
14381437

14391438
request.state = RequestState.DONE
14401439

1441-
if context.session and context.session.is_usable:
1442-
context.session.mark_good()
1440+
if session and session.is_usable:
1441+
session.mark_good()
14431442

14441443
self._statistics.record_request_processing_finish(request.unique_key)
14451444

14461445
except RequestCollisionError as request_error:
1447-
context.request.no_retry = True
1446+
request.no_retry = True
14481447
await self._handle_request_error(context, request_error)
14491448

14501449
except RequestHandlerError as primary_error:
@@ -1459,7 +1458,7 @@ async def __run_task_function(self) -> None:
14591458
await self._handle_request_error(primary_error.crawling_context, primary_error.wrapped_exception)
14601459

14611460
except SessionError as session_error:
1462-
if not context.session:
1461+
if not session:
14631462
raise RuntimeError('SessionError raised in a crawling context without a session') from session_error
14641463

14651464
if self._error_handler:
@@ -1469,16 +1468,17 @@ async def __run_task_function(self) -> None:
14691468
exc_only = ''.join(traceback.format_exception_only(session_error)).strip()
14701469
self._logger.warning('Encountered "%s", rotating session and retrying...', exc_only)
14711470

1472-
context.session.retire()
1471+
if session:
1472+
session.retire()
14731473

14741474
# Increment session rotation count.
1475-
context.request.session_rotation_count = (context.request.session_rotation_count or 0) + 1
1475+
request.session_rotation_count = (request.session_rotation_count or 0) + 1
14761476

14771477
await request_manager.reclaim_request(request)
14781478
await self._statistics.error_tracker_retry.add(error=session_error, context=context)
14791479
else:
14801480
await wait_for(
1481-
lambda: request_manager.mark_request_as_handled(context.request),
1481+
lambda: request_manager.mark_request_as_handled(request),
14821482
timeout=self._internal_timeout,
14831483
timeout_message='Marking request as handled timed out after '
14841484
f'{self._internal_timeout.total_seconds()} seconds',
@@ -1493,7 +1493,7 @@ async def __run_task_function(self) -> None:
14931493
self._logger.debug('The context pipeline was interrupted', exc_info=interrupted_error)
14941494

14951495
await wait_for(
1496-
lambda: request_manager.mark_request_as_handled(context.request),
1496+
lambda: request_manager.mark_request_as_handled(request),
14971497
timeout=self._internal_timeout,
14981498
timeout_message='Marking request as handled timed out after '
14991499
f'{self._internal_timeout.total_seconds()} seconds',
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Iterator

    from crawlee._request import Request
    from crawlee.sessions import Session

    from ._basic_crawling_context import BasicCrawlingContext


@contextmanager
def swaped_context(
    context: BasicCrawlingContext,
    request: Request,
    session: Session | None,
) -> Iterator[None]:
    """Run the wrapped code, then put the given request and session back on the context.

    The context may carry isolated copies while a handler executes; on exit — even
    when the wrapped code raises — the supplied originals are written back so that
    subsequent handlers do not observe leftover state.
    """
    try:
        yield
    finally:
        # NOTE(review): `object.__setattr__` suggests the context blocks normal
        # attribute assignment (e.g. a frozen dataclass) — confirm against its class.
        for attribute, original in (('request', request), ('session', session)):
            object.__setattr__(context, attribute, original)

0 commit comments

Comments
 (0)