Skip to content

Commit 22a6e0e

Browse files
committed
returns the fraction of the budget that has been spent
1 parent 5ce2495 commit 22a6e0e

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

kernel_tuner/strategies/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(self, searchspace: Searchspace, tuning_options, runner, *, scaling=
6060
self.scaling = scaling
6161
self.searchspace = searchspace
6262
self.results = []
63+
self.budget_spent_fraction = 0.0
6364

6465
def __call__(self, x, check_restrictions=True):
6566
"""Cost function used by almost all strategies."""
@@ -70,7 +71,7 @@ def __call__(self, x, check_restrictions=True):
7071
logging.debug('x: ' + str(x))
7172

7273
# check if max_fevals is reached or time limit is exceeded
73-
util.check_stop_criterion(self.tuning_options)
74+
self.budget_spent_fraction = util.check_stop_criterion(self.tuning_options)
7475

7576
# snap values in x to nearest actual value for each parameter, unscale x if needed
7677
if self.snap:

kernel_tuner/util.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,12 +189,28 @@ def check_argument_list(kernel_name, kernel_string, args):
189189
warnings.warn(errors[0], UserWarning)
190190

191191

192-
def check_stop_criterion(to):
193-
"""Checks if max_fevals is reached or time limit is exceeded."""
194-
if "max_fevals" in to and len(to.unique_results) >= to.max_fevals:
195-
raise StopCriterionReached("max_fevals reached")
196-
if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) > to.time_limit):
197-
raise StopCriterionReached("time limit exceeded")
192+
def check_stop_criterion(to: dict) -> float:
193+
"""Check if the stop criterion is reached.
194+
195+
Args:
196+
to (dict): tuning options.
197+
198+
Raises:
199+
StopCriterionReached: if the max_fevals is reached or time limit is exceeded.
200+
201+
Returns:
202+
float: fraction of budget spent.
203+
"""
204+
if "max_fevals" in to:
205+
if len(to.unique_results) >= to.max_fevals:
206+
raise StopCriterionReached(f"max_fevals ({to.max_fevals}) reached")
207+
return len(to.unique_results) / to.max_fevals
208+
if "time_limit" in to:
209+
time_spent = (time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)
210+
if time_spent > to.time_limit:
211+
raise StopCriterionReached("time limit exceeded")
212+
return time_spent / to.time_limit
213+
198214

199215

200216
def check_tune_params_list(tune_params, observers, simulation_mode=False):

0 commit comments

Comments
 (0)