Skip to content

Commit e7f2179

Browse files
committed
Add an auto_goal function and use it in the Runner
1 parent ae44fae commit e7f2179

File tree

1 file changed

+91
-6
lines changed

1 file changed

+91
-6
lines changed

adaptive/runner.py

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import traceback
1414
import warnings
1515
from contextlib import suppress
16+
from datetime import datetime, timedelta
1617
from typing import TYPE_CHECKING, Any, Callable
1718

1819
import loky
@@ -129,10 +130,11 @@ def __init__(
129130
shutdown_executor=False,
130131
retries=0,
131132
raise_if_retries_exceeded=True,
133+
allow_running_forever=False,
132134
):
133135

134136
self.executor = _ensure_executor(executor)
135-
self.goal = goal
137+
self.goal = auto_goal(goal, learner, allow_running_forever)
136138

137139
self._max_tasks = ntasks
138140

@@ -396,6 +398,7 @@ def __init__(
396398
shutdown_executor=shutdown_executor,
397399
retries=retries,
398400
raise_if_retries_exceeded=raise_if_retries_exceeded,
401+
allow_running_forever=False,
399402
)
400403
self._run()
401404

@@ -518,11 +521,6 @@ def __init__(
518521
raise_if_retries_exceeded=True,
519522
):
520523

521-
if goal is None:
522-
523-
def goal(_):
524-
return False
525-
526524
if (
527525
executor is None
528526
and _default_executor is concurrent.ProcessPoolExecutor
@@ -548,6 +546,7 @@ def goal(_):
548546
shutdown_executor=shutdown_executor,
549547
retries=retries,
550548
raise_if_retries_exceeded=raise_if_retries_exceeded,
549+
allow_running_forever=True,
551550
)
552551
self.ioloop = ioloop or asyncio.get_event_loop()
553552
self.task = None
@@ -861,3 +860,89 @@ def _get_ncores(ex):
861860
return ex._pool.size # not public API!
862861
else:
863862
raise TypeError(f"Cannot get number of cores for {ex.__class__}")
863+
864+
865+
class _TimeGoal:
866+
def __init__(self, dt: timedelta | datetime):
867+
self.dt = dt
868+
self.start_time = None
869+
870+
def __call__(self, _):
871+
if isinstance(self.dt, timedelta):
872+
if self.start_time is None:
873+
self.start_time = datetime.now()
874+
return datetime.now() - self.start_time > self.dt
875+
elif isinstance(self.dt, datetime):
876+
return datetime.now() > self.dt
877+
else:
878+
raise TypeError(f"{self.dt=} is not a datetime or timedelta.")
879+
880+
881+
def auto_goal(
882+
goal: Callable[[BaseLearner], bool] | int | float | datetime | timedelta | None,
883+
learner: BaseLearner,
884+
allow_running_forever: bool = True,
885+
):
886+
"""Extract a goal from the learners.
887+
888+
Parameters
889+
----------
890+
goal
891+
The goal to extract. Can be a callable, an integer, a float, a datetime,
892+
a timedelta or None.
893+
If it is a callable, it is returned as is.
894+
If it is an integer, the goal is reached after that many points have been
895+
returned.
896+
If it is a float, the goal is reached when the learner has reached a loss
897+
less than that.
898+
If it is a datetime, the goal is reached when the current time is after the
899+
datetime.
900+
If it is a timedelta, the goal is reached when the current time is after
901+
the start time plus that timedelta.
902+
If it is None, and
903+
- the learner type is `adaptive.SequenceLearner`, it continues until
904+
it no more points to add
905+
- the learner type is `adaptive.Integrator`, it continues until the
906+
error is less than the tolerance.
907+
- otherwise, it continues forever, unless `allow_running_forever` is
908+
False, in which case it raises a ValueError.
909+
learner
910+
Learner for which to determine the goal.
911+
allow_running_forever
912+
If True, and the goal is None and the learner is not a SequenceLearner,
913+
then a goal that never stops is returned, otherwise an exception is raised.
914+
915+
Returns
916+
-------
917+
Callable[[adaptive.BaseLearner], bool]
918+
"""
919+
from adaptive import BalancingLearner, IntegratorLearner, SequenceLearner
920+
921+
if callable(goal):
922+
return goal
923+
if isinstance(goal, float):
924+
return lambda learner: learner.loss() <= goal
925+
if isinstance(learner, BalancingLearner):
926+
# Note that the float loss goal is more efficiently implemented in the
927+
# BalancingLearner itself. That is why the previous if statement is
928+
# above this one.
929+
goals = [auto_goal(goal, l, allow_running_forever) for l in learner.learners]
930+
return lambda learner: all(goal(l) for l, goal in zip(learner.learners, goals))
931+
if isinstance(goal, int):
932+
return lambda learner: learner.npoints >= goal
933+
if isinstance(goal, (timedelta, datetime)):
934+
return _TimeGoal(goal)
935+
if goal is None:
936+
if isinstance(learner, SequenceLearner):
937+
return SequenceLearner.done
938+
if isinstance(learner, IntegratorLearner):
939+
return IntegratorLearner.done
940+
warnings.warn("Goal is None which means the learners continue forever!")
941+
if allow_running_forever:
942+
return lambda _: False
943+
else:
944+
raise ValueError(
945+
"Goal is None which means the learners"
946+
" continue forever and this is not allowed."
947+
)
948+
raise ValueError("Cannot determine goal from {goal}.")

0 commit comments

Comments
 (0)