13
13
import traceback
14
14
import warnings
15
15
from contextlib import suppress
16
+ from datetime import datetime , timedelta
16
17
from typing import TYPE_CHECKING , Any , Callable
17
18
18
19
import loky
@@ -129,10 +130,11 @@ def __init__(
129
130
shutdown_executor = False ,
130
131
retries = 0 ,
131
132
raise_if_retries_exceeded = True ,
133
+ allow_running_forever = False ,
132
134
):
133
135
134
136
self .executor = _ensure_executor (executor )
135
- self .goal = goal
137
+ self .goal = auto_goal ( goal , learner , allow_running_forever )
136
138
137
139
self ._max_tasks = ntasks
138
140
@@ -396,6 +398,7 @@ def __init__(
396
398
shutdown_executor = shutdown_executor ,
397
399
retries = retries ,
398
400
raise_if_retries_exceeded = raise_if_retries_exceeded ,
401
+ allow_running_forever = False ,
399
402
)
400
403
self ._run ()
401
404
@@ -518,11 +521,6 @@ def __init__(
518
521
raise_if_retries_exceeded = True ,
519
522
):
520
523
521
- if goal is None :
522
-
523
- def goal (_ ):
524
- return False
525
-
526
524
if (
527
525
executor is None
528
526
and _default_executor is concurrent .ProcessPoolExecutor
@@ -548,6 +546,7 @@ def goal(_):
548
546
shutdown_executor = shutdown_executor ,
549
547
retries = retries ,
550
548
raise_if_retries_exceeded = raise_if_retries_exceeded ,
549
+ allow_running_forever = True ,
551
550
)
552
551
self .ioloop = ioloop or asyncio .get_event_loop ()
553
552
self .task = None
@@ -861,3 +860,89 @@ def _get_ncores(ex):
861
860
return ex ._pool .size # not public API!
862
861
else :
863
862
raise TypeError (f"Cannot get number of cores for { ex .__class__ } " )
863
+
864
+
865
+ class _TimeGoal :
866
+ def __init__ (self , dt : timedelta | datetime ):
867
+ self .dt = dt
868
+ self .start_time = None
869
+
870
+ def __call__ (self , _ ):
871
+ if isinstance (self .dt , timedelta ):
872
+ if self .start_time is None :
873
+ self .start_time = datetime .now ()
874
+ return datetime .now () - self .start_time > self .dt
875
+ elif isinstance (self .dt , datetime ):
876
+ return datetime .now () > self .dt
877
+ else :
878
+ raise TypeError (f"{ self .dt = } is not a datetime or timedelta." )
879
+
880
+
881
+ def auto_goal (
882
+ goal : Callable [[BaseLearner ], bool ] | int | float | datetime | timedelta | None ,
883
+ learner : BaseLearner ,
884
+ allow_running_forever : bool = True ,
885
+ ):
886
+ """Extract a goal from the learners.
887
+
888
+ Parameters
889
+ ----------
890
+ goal
891
+ The goal to extract. Can be a callable, an integer, a float, a datetime,
892
+ a timedelta or None.
893
+ If it is a callable, it is returned as is.
894
+ If it is an integer, the goal is reached after that many points have been
895
+ returned.
896
+ If it is a float, the goal is reached when the learner has reached a loss
897
+ less than that.
898
+ If it is a datetime, the goal is reached when the current time is after the
899
+ datetime.
900
+ If it is a timedelta, the goal is reached when the current time is after
901
+ the start time plus that timedelta.
902
+ If it is None, and
903
+ - the learner type is `adaptive.SequenceLearner`, it continues until
904
+ it no more points to add
905
+ - the learner type is `adaptive.Integrator`, it continues until the
906
+ error is less than the tolerance.
907
+ - otherwise, it continues forever, unless `allow_running_forever` is
908
+ False, in which case it raises a ValueError.
909
+ learner
910
+ Learner for which to determine the goal.
911
+ allow_running_forever
912
+ If True, and the goal is None and the learner is not a SequenceLearner,
913
+ then a goal that never stops is returned, otherwise an exception is raised.
914
+
915
+ Returns
916
+ -------
917
+ Callable[[adaptive.BaseLearner], bool]
918
+ """
919
+ from adaptive import BalancingLearner , IntegratorLearner , SequenceLearner
920
+
921
+ if callable (goal ):
922
+ return goal
923
+ if isinstance (goal , float ):
924
+ return lambda learner : learner .loss () <= goal
925
+ if isinstance (learner , BalancingLearner ):
926
+ # Note that the float loss goal is more efficiently implemented in the
927
+ # BalancingLearner itself. That is why the previous if statement is
928
+ # above this one.
929
+ goals = [auto_goal (goal , l , allow_running_forever ) for l in learner .learners ]
930
+ return lambda learner : all (goal (l ) for l , goal in zip (learner .learners , goals ))
931
+ if isinstance (goal , int ):
932
+ return lambda learner : learner .npoints >= goal
933
+ if isinstance (goal , (timedelta , datetime )):
934
+ return _TimeGoal (goal )
935
+ if goal is None :
936
+ if isinstance (learner , SequenceLearner ):
937
+ return SequenceLearner .done
938
+ if isinstance (learner , IntegratorLearner ):
939
+ return IntegratorLearner .done
940
+ warnings .warn ("Goal is None which means the learners continue forever!" )
941
+ if allow_running_forever :
942
+ return lambda _ : False
943
+ else :
944
+ raise ValueError (
945
+ "Goal is None which means the learners"
946
+ " continue forever and this is not allowed."
947
+ )
948
+ raise ValueError ("Cannot determine goal from {goal}." )
0 commit comments