|
3 | 3 | import collections
|
4 | 4 | import itertools
|
5 | 5 | import operator
|
6 |
| -from typing import TYPE_CHECKING, Collection, Generic, Iterable, Mapping |
| 6 | +from typing import TYPE_CHECKING, Generic |
7 | 7 |
|
8 | 8 | from ..structs import (
|
9 | 9 | CT,
|
|
27 | 27 | )
|
28 | 28 |
|
29 | 29 | if TYPE_CHECKING:
|
| 30 | + from collections.abc import Collection, Iterable, Mapping |
| 31 | + |
30 | 32 | from ..providers import AbstractProvider, Preference
|
31 | 33 | from ..reporters import BaseReporter
|
32 | 34 |
|
| 35 | +_OPTIMISTIC_BACKJUMPING_RATIO: float = 0.1 |
| 36 | + |
33 | 37 |
|
34 | 38 | def _build_result(state: State[RT, CT, KT]) -> Result[RT, CT, KT]:
|
35 | 39 | mapping = state.mapping
|
@@ -77,6 +81,11 @@ def __init__(
|
77 | 81 | self._r = reporter
|
78 | 82 | self._states: list[State[RT, CT, KT]] = []
|
79 | 83 |
|
| 84 | + # Optimistic backjumping variables |
| 85 | + self._optimistic_backjumping_ratio = _OPTIMISTIC_BACKJUMPING_RATIO |
| 86 | + self._save_states: list[State[RT, CT, KT]] | None = None |
| 87 | + self._optimistic_start_round: int | None = None |
| 88 | + |
80 | 89 | @property
|
81 | 90 | def state(self) -> State[RT, CT, KT]:
|
82 | 91 | try:
|
@@ -274,6 +283,25 @@ def _patch_criteria(
|
274 | 283 | )
|
275 | 284 | return True
|
276 | 285 |
|
| 286 | + def _save_state(self) -> None: |
| 287 | + """Save states for potential rollback if optimistic backjumping fails.""" |
| 288 | + if self._save_states is None: |
| 289 | + self._save_states = [ |
| 290 | + State( |
| 291 | + mapping=s.mapping.copy(), |
| 292 | + criteria=s.criteria.copy(), |
| 293 | + backtrack_causes=s.backtrack_causes[:], |
| 294 | + ) |
| 295 | + for s in self._states |
| 296 | + ] |
| 297 | + |
| 298 | + def _rollback_states(self) -> None: |
| 299 | + """Rollback states and disable optimistic backjumping.""" |
| 300 | + self._optimistic_backjumping_ratio = 0.0 |
| 301 | + if self._save_states: |
| 302 | + self._states = self._save_states |
| 303 | + self._save_states = None |
| 304 | + |
277 | 305 | def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
|
278 | 306 | """Perform backjumping.
|
279 | 307 |
|
@@ -324,13 +352,26 @@ def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
|
324 | 352 | except (IndexError, KeyError):
|
325 | 353 | raise ResolutionImpossible(causes) from None
|
326 | 354 |
|
327 |
| - # Only backjump if the current broken state is |
328 |
| - # an incompatible dependency |
329 |
| - if name not in incompatible_deps: |
| 355 | + if ( |
| 356 | + not self._optimistic_backjumping_ratio |
| 357 | + and name not in incompatible_deps |
| 358 | + ): |
| 359 | + # For safe backjumping only backjump if the current dependency |
| 360 | + # is not the same as the incompatible dependency |
330 | 361 | break
|
331 | 362 |
|
| 363 | + # On the first time a non-safe backjump is done the state |
| 364 | + # is saved so we can restore it later if the resolution fails |
| 365 | + if ( |
| 366 | + self._optimistic_backjumping_ratio |
| 367 | + and self._save_states is None |
| 368 | + and name not in incompatible_deps |
| 369 | + ): |
| 370 | + self._save_state() |
| 371 | + |
332 | 372 | # If the current dependencies and the incompatible dependencies
|
333 |
| - # are overlapping then we have found a cause of the incompatibility |
| 373 | + # are overlapping then we have likely found a cause of the |
| 374 | + # incompatibility |
334 | 375 | current_dependencies = {
|
335 | 376 | self._p.identify(d) for d in self._p.get_dependencies(candidate)
|
336 | 377 | }
|
@@ -394,9 +435,32 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT,
|
394 | 435 | # pinning the virtual "root" package in the graph.
|
395 | 436 | self._push_new_state()
|
396 | 437 |
|
| 438 | + # Variables for optimistic backjumping |
| 439 | + optimistic_rounds_cutoff: int | None = None |
| 440 | + optimistic_backjumping_start_round: int | None = None |
| 441 | + |
397 | 442 | for round_index in range(max_rounds):
|
398 | 443 | self._r.starting_round(index=round_index)
|
399 | 444 |
|
| 445 | + # Handle if optimistic backjumping has been running for too long |
| 446 | + if self._optimistic_backjumping_ratio and self._save_states is not None: |
| 447 | + if optimistic_backjumping_start_round is None: |
| 448 | + optimistic_backjumping_start_round = round_index |
| 449 | + optimistic_rounds_cutoff = int( |
| 450 | + (max_rounds - round_index) * self._optimistic_backjumping_ratio |
| 451 | + ) |
| 452 | + |
| 453 | + if optimistic_rounds_cutoff <= 0: |
| 454 | + self._rollback_states() |
| 455 | + continue |
| 456 | + elif optimistic_rounds_cutoff is not None: |
| 457 | + if ( |
| 458 | + round_index - optimistic_backjumping_start_round |
| 459 | + >= optimistic_rounds_cutoff |
| 460 | + ): |
| 461 | + self._rollback_states() |
| 462 | + continue |
| 463 | + |
400 | 464 | unsatisfied_names = [
|
401 | 465 | key
|
402 | 466 | for key, criterion in self.state.criteria.items()
|
@@ -448,12 +512,29 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT,
|
448 | 512 | # Backjump if pinning fails. The backjump process puts us in
|
449 | 513 | # an unpinned state, so we can work on it in the next round.
|
450 | 514 | self._r.resolving_conflicts(causes=causes)
|
451 |
| - success = self._backjump(causes) |
452 |
| - self.state.backtrack_causes[:] = causes |
453 | 515 |
|
454 |
| - # Dead ends everywhere. Give up. |
455 |
| - if not success: |
456 |
| - raise ResolutionImpossible(self.state.backtrack_causes) |
| 516 | + try: |
| 517 | + success = self._backjump(causes) |
| 518 | + except ResolutionImpossible: |
| 519 | + if self._optimistic_backjumping_ratio and self._save_states: |
| 520 | + failed_optimistic_backjumping = True |
| 521 | + else: |
| 522 | + raise |
| 523 | + else: |
| 524 | + failed_optimistic_backjumping = bool( |
| 525 | + not success |
| 526 | + and self._optimistic_backjumping_ratio |
| 527 | + and self._save_states |
| 528 | + ) |
| 529 | + |
| 530 | + if failed_optimistic_backjumping and self._save_states: |
| 531 | + self._rollback_states() |
| 532 | + else: |
| 533 | + self.state.backtrack_causes[:] = causes |
| 534 | + |
| 535 | + # Dead ends everywhere. Give up. |
| 536 | + if not success: |
| 537 | + raise ResolutionImpossible(self.state.backtrack_causes) |
457 | 538 | else:
|
458 | 539 | # discard as information sources any invalidated names
|
459 | 540 | # (unsatisfied names that were previously satisfied)
|
|
0 commit comments