Skip to content

Commit 07b599f

Browse files
committed
Fix CostFunc returning results containing InvalidConfig
1 parent da4fa97 commit 07b599f

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

kernel_tuner/strategies/common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _run_configs(self, xs, check_restrictions=True):
146146
batch_keys = [] # The keys of the configs to run
147147
pending_indices_by_key = dict() # Maps key => where to store result in `final_results`
148148
final_results = [] # List returned to the user
149+
legal_indices = [] # Indices in `final_results` that are legal
149150

150151
# Loop over all configurations. For each configurations there are four cases:
151152
# 1. The configuration is invalid, we can skip it
@@ -166,6 +167,7 @@ def _run_configs(self, xs, check_restrictions=True):
166167
# 2. Attempt to retrieve from `unique_results`
167168
elif key in self.unique_results:
168169
result = dict(self.unique_results[key])
170+
legal_indices.append(len(final_results))
169171
final_results.append(result)
170172

171173
# 3. We have already seen this config in the current batch
@@ -190,6 +192,7 @@ def _run_configs(self, xs, check_restrictions=True):
190192

191193
# set in the results array
192194
for index in pending_indices_by_key[key]:
195+
legal_indices.append(index)
193196
final_results[index] = dict(result)
194197

195198
# Disable the timings. Only the first result must get these.
@@ -200,10 +203,9 @@ def _run_configs(self, xs, check_restrictions=True):
200203
# Put result in `unique_results`
201204
self.unique_results[key] = result
202205

203-
for result in final_results:
204-
# Skip if None. Result is missing if runner exhausted the budget
205-
if result is not None:
206-
self.results.append(result)
206+
# Only things in `legal_indices` are valid results
207+
for index in sorted(legal_indices):
208+
self.results.append(result)
207209

208210
# upon returning from this function control will be given back to the strategy, so reset the start time
209211
self.runner.last_strategy_start_time = perf_counter()

kernel_tuner/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ def process_metrics(params, metrics):
732732
:rtype: dict
733733
734734
"""
735-
if metrics:
735+
if metrics is not None:
736736
if not isinstance(metrics, dict):
737737
raise ValueError("metrics should be a dictionary to preserve order and support composability")
738738
for k, v in metrics.items():

test/strategies/test_strategies.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def vector_add():
5353
strategies.append(pytest.param(s, marks=skip_if_no_pyatf))
5454
else:
5555
strategies.append(s)
56+
5657
@pytest.mark.parametrize('strategy', strategies)
5758
def test_strategies(vector_add, strategy):
5859
options = dict(popsize=5, neighbor='adjacent')

test/test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
import kernel_tuner.strategies.common as common
7+
import kernel_tuner.util
78
from kernel_tuner.interface import Options
89
from kernel_tuner.searchspace import Searchspace
910

@@ -13,6 +14,7 @@ def tuning_options():
1314
tuning_options["strategy_options"] = {}
1415
tuning_options["objective"] = "time"
1516
tuning_options["objective_higher_is_better"] = False
17+
tuning_options["budget"] = kernel_tuner.util.TuningBudget()
1618
return tuning_options
1719

1820

0 commit comments

Comments
 (0)