Skip to content

Commit 7f2bdc8

Browse files
improve defaults for objective_higher_is_better
1 parent a67a104 commit 7f2bdc8

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

kernel_tuner/interface.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import numpy
3131
from time import perf_counter
3232

33+
from kernel_tuner.integration import get_objective_defaults
34+
3335
import kernel_tuner.util as util
3436
import kernel_tuner.core as core
3537

@@ -449,7 +451,7 @@ def _get_docstring(opts):
449451
def tune_kernel(kernel_name, kernel_source, problem_size, arguments, tune_params, grid_div_x=None, grid_div_y=None, grid_div_z=None, restrictions=None,
450452
answer=None, atol=1e-6, verify=None, verbose=False, lang=None, device=0, platform=0, smem_args=None, cmem_args=None, texmem_args=None,
451453
compiler=None, compiler_options=None, log=None, iterations=7, block_size_names=None, quiet=False, strategy=None, strategy_options=None,
452-
cache=None, metrics=None, simulation_mode=False, observers=None, objective=None, objective_higher_is_better=False):
454+
cache=None, metrics=None, simulation_mode=False, observers=None, objective=None, objective_higher_is_better=None):
453455
start_overhead_time = perf_counter()
454456
if log:
455457
logging.basicConfig(filename=kernel_name + datetime.now().strftime('%Y%m%d-%H:%M:%S') + '.log', level=log)
@@ -459,8 +461,7 @@ def tune_kernel(kernel_name, kernel_source, problem_size, arguments, tune_params
459461
_check_user_input(kernel_name, kernelsource, arguments, block_size_names)
460462

461463
# default objective if none is specified
462-
if objective is None:
463-
objective = "time"
464+
objective, objective_higher_is_better = get_objective_defaults(objective, objective_higher_is_better)
464465

465466
# check for forbidden names in tune parameters
466467
util.check_tune_params_list(tune_params)

0 commit comments

Comments
 (0)