Skip to content

Commit 0f7d320

Browse files
committed
Allow timedelta_goal to be an int
1 parent be14576 commit 0f7d320

File tree

1 file changed

+15
-24
lines changed

1 file changed

+15
-24
lines changed

adaptive/runner.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class BaseRunner(metaclass=abc.ABCMeta):
8888
Convenience argument, use instead of ``goal``. The end condition for the
8989
calculation. Stop when the current time is larger or equal than this
9090
value.
91-
timedelta_goal : timedelta, optional
91+
timedelta_goal : timedelta or int, optional
9292
Convenience argument, use instead of ``goal``. The end condition for the
9393
calculation. Stop when the current time is larger or equal than
9494
``start_time + timedelta_goal``.
@@ -507,7 +507,7 @@ class AsyncRunner(BaseRunner):
507507
Convenience argument, use instead of ``goal``. The end condition for the
508508
calculation. Stop when the current time is larger or equal than this
509509
value.
510-
timedelta_goal : timedelta, optional
510+
timedelta_goal : timedelta or int, optional
511511
Convenience argument, use instead of ``goal``. The end condition for the
512512
calculation. Stop when the current time is larger or equal than
513513
``start_time + timedelta_goal``.
@@ -831,7 +831,7 @@ def simple(
831831
Convenience argument, use instead of ``goal``. The end condition for the
832832
calculation. Stop when the current time is larger or equal than this
833833
value.
834-
timedelta_goal : timedelta, optional
834+
timedelta_goal : timedelta or int, optional
835835
Convenience argument, use instead of ``goal``. The end condition for the
836836
calculation. Stop when the current time is larger or equal than
837837
``start_time + timedelta_goal``.
@@ -970,7 +970,9 @@ def _get_ncores(ex):
970970

971971

972972
class _TimeGoal:
973-
def __init__(self, dt: timedelta | datetime):
973+
def __init__(self, dt: timedelta | datetime | int):
974+
if isinstance(dt, int):
975+
self.dt = timedelta(seconds=dt)
974976
self.dt = dt
975977
self.start_time = None
976978

@@ -989,33 +991,22 @@ def auto_goal(
989991
loss: float | None = None,
990992
npoints: int | None = None,
991993
datetime: datetime | None = None,
992-
timedelta: timedelta | None = None,
994+
timedelta: timedelta | int | None = None,
993995
learner: BaseLearner | None = None,
994996
allow_running_forever: bool = True,
995997
) -> Callable[[BaseLearner], bool]:
996998
"""Extract a goal from the learners.
997999
9981000
Parameters
9991001
----------
1000-
goal
1001-
The goal to extract. Can be a callable, an integer, a float, a datetime,
1002-
a timedelta or None.
1003-
If the type of `goal` is:
1004-
* ``callable``, it is returned as is.
1005-
* ``int``, the goal is reached after that many points have been added.
1006-
* ``float``, the goal is reached when the learner has reached a loss
1007-
equal or less than that.
1008-
* `datetime.datetime`, the goal is reached when the current time is after the
1009-
datetime.
1010-
* `datetime.timedelta`, the goal is reached when the current time is after
1011-
the start time plus that timedelta.
1012-
* ``None`` and
1013-
* the learner type is `adaptive.SequenceLearner`, it continues until
1014-
it no more points to add
1015-
* the learner type is `adaptive.IntegratorLearner`, it continues until the
1016-
error is less than the tolerance specified in the learner.
1017-
* otherwise, it continues forever, unless ``allow_running_forever`` is
1018-
False, in which case it raises a ValueError.
1002+
loss
1003+
TODO
1004+
npoints
1005+
TODO
1006+
datetime
1007+
TODO
1008+
timedelta
1009+
TODO
10191010
learner
10201011
Learner for which to determine the goal.
10211012
allow_running_forever

0 commit comments

Comments
 (0)