
Commit ceaa96c

Move check for non-unique configuration from CostFunc to runners

1 parent a1c87db

File tree

4 files changed: +93 -86 lines


kernel_tuner/runners/parallel.py

Lines changed: 44 additions & 29 deletions
```diff
@@ -291,43 +291,53 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
         jobs = []  # Jobs that need to be executed
         results = []  # Results that will be returned at the end
         key2index = dict()  # Used to insert job result back into `results`
-
-        total_worker_time = 0
+        duplicate_entries = []  # Stores (i, j) if `i` is a duplicate of `j`.
 
         # Select jobs which are not in the cache
         for index, config in enumerate(parameter_space):
             params = dict(zip(tuning_options.tune_params.keys(), config))
             key = ",".join([str(i) for i in config])
 
+            # Element is in cache
             if key in tuning_options.cache:
-                cache_entry = tuning_options.cache[key]
-
                 # We must disable the timings as otherwise these will be counted
                 # as part of the total_compile/benchmark/verification_time
-                results.append(disable_benchmark_timings(cache_entry))
+                result = disable_benchmark_timings(tuning_options.cache[key])
+
+                # recompute metrics for this entry
+                result = process_metrics(result, metrics)
+
+                results.append(result)
+
+            # Element is a duplicate entry in `parameter_space`
+            elif key in key2index:
+                duplicate_entries.append((index, key2index[key]))
+                results.append(None)
+
+            # Element must become a job
             else:
-                assert key not in key2index, "duplicate jobs submitted"
                 key2index[key] = index
-
                 jobs.append((key, params))
                 results.append(None)
 
+        total_worker_time = 0
+
         # Submit jobs and wait for them to finish
         for key, result in self.submit_jobs(jobs, tuning_options.budget):
             # `None` indicates that no result is available since the budget is exceeded.
             # We can skip it, meaning that `results` contains `None`s for these entries
             if result is None:
                 continue
 
-            # Store the result into the output array
-            results[key2index[key]] = result
-
             # Collect total time spent by worker
             total_worker_time += (
                 result["compile_time"] + result["verification_time"] + result["benchmark_time"]
             )
 
-            if isinstance(result.get(objective), ErrorConfig):
+            # only compute metrics on configs that have not errored
+            if not isinstance(result.get(objective), ErrorConfig):
+                result = process_metrics(result, metrics)
+            else:
                 logging.error(
                     "kernel configuration {key} was skipped silently due to compile or runtime failure",
                     key,
@@ -341,29 +351,34 @@ def run(self, parameter_space, tuning_options) -> List[Optional[dict]]:
             # add configuration to cache
             store_cache(key, result, tuning_options.cachefile, tuning_options.cache)
 
-        total_time = 1000 * (perf_counter() - self.start_time)
-        self.start_time = perf_counter()
+            # Store the result into the output array
+            results[key2index[key]] = result
 
-        strategy_time = self.last_strategy_time
-        self.last_strategy_time = 0
+        # Fix duplicate entries. Duplicate entries do not get benchmark timings
+        # as otherwise we would count them multiple times in the total
+        for i, j in duplicate_entries:
+            if results[j]:
+                results[i] = disable_benchmark_timings(results[j])
 
-        runner_time = total_time - strategy_time
-        framework_time = max(runner_time * len(self.workers) - total_worker_time, 0)
+        # Count the number of valid results
+        num_valid_results = sum(bool(r) for r in results)
 
-        num_valid_results = sum(bool(r) for r in results)  # Count the number of valid results
+        # If there are valid results, set timings
+        if num_valid_results > 0:
+            total_time = 1000 * (perf_counter() - self.start_time)
+            self.start_time = perf_counter()
 
-        # Post-process all the results
-        for result in results:
-            # Skip missing results
-            if not result:
-                continue
+            strategy_time = self.last_strategy_time
+            self.last_strategy_time = 0
 
-            # Amortize the time over all the results
-            result["strategy_time"] = strategy_time / num_valid_results
-            result["framework_time"] = framework_time / num_valid_results
+            runner_time = total_time - strategy_time
+            framework_time = max(runner_time * len(self.workers) - total_worker_time, 0)
 
-            # only compute metrics on configs that have not errored
-            if not isinstance(result.get(objective), ErrorConfig):
-                result = process_metrics(result, metrics)
+            # Post-process all the results
+            for result in results:
+                # Amortize the time over all the results
+                if result:
+                    result["strategy_time"] = strategy_time / num_valid_results
+                    result["framework_time"] = framework_time / num_valid_results
 
         return results
```

kernel_tuner/runners/sequential.py

Lines changed: 25 additions & 17 deletions
```diff
@@ -64,9 +64,7 @@ def run(self, parameter_space, tuning_options):
         logging.debug("sequential runner started for " + self.kernel_options.kernel_name)
 
         results = []
-
-        # self.last_strategy_time is set by cost_func
-        strategy_time_per_config = self.last_strategy_time / len(parameter_space) if len(parameter_space) > 0 else 0
+        total_worker_time = 0
 
         # iterate over parameter space
         for element in parameter_space:
@@ -101,15 +99,22 @@ def run(self, parameter_space, tuning_options):
                 self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options
             )
 
+            # Collect total time spent by worker
+            total_worker_time += (
+                result["compile_time"] + result["verification_time"] + result["benchmark_time"]
+            )
+
             params.update(result)
 
-            if tuning_options.objective in result and isinstance(result[tuning_options.objective], ErrorConfig):
+            if isinstance(result.get(tuning_options.objective), ErrorConfig):
                 logging.debug("kernel configuration was skipped silently due to compile or runtime failure")
 
             # only compute metrics on configs that have not errored
-            if tuning_options.metrics and not isinstance(params.get(tuning_options.objective), ErrorConfig):
+            if not isinstance(params.get(tuning_options.objective), ErrorConfig):
                 params = process_metrics(params, tuning_options.metrics)
 
+            params["timestamp"] = str(datetime.now(timezone.utc))
+
             if result:
                 # print configuration to the console
                 print_config_output(tuning_options.tune_params, params, self.quiet, tuning_options.metrics, self.units)
@@ -120,20 +125,23 @@ def run(self, parameter_space, tuning_options):
             # all visited configurations are added to results to provide a trace for optimization strategies
             results.append(params)
 
+        num_valid_results = sum(bool(r) for r in results)  # Count the number of valid results
+
+        if num_valid_results > 0:
             # get the framework time by estimating based on other times
-            total_time = 1000 * (perf_counter() - self.start_time) - warmup_time
+            total_time = 1000 * (perf_counter() - self.start_time)
             self.start_time = perf_counter()
 
-            params["strategy_time"] = strategy_time_per_config
-            params["framework_time"] = max(
-                total_time
-                - (
-                    params["compile_time"]
-                    + params["verification_time"]
-                    + params["benchmark_time"]
-                ),
-                0,
-            )
-            params["timestamp"] = str(datetime.now(timezone.utc))
+            strategy_time = self.last_strategy_time
+            self.last_strategy_time = 0
+
+            framework_time = max(total_time - strategy_time - total_worker_time, 0)
+
+            # Post-process all the results
+            for result in results:
+                # Amortize the time over all the results
+                if result:
+                    result["strategy_time"] = strategy_time / num_valid_results
+                    result["framework_time"] = framework_time / num_valid_results
 
         return results
```
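
The timing model can be summarized as: whatever wall-clock time is not accounted for by the strategy or the workers is attributed to the framework, and both strategy and framework time are then amortized evenly over the valid results. A small sketch with made-up numbers (field names follow the diff; `total_time` and `strategy_time` are hypothetical stand-ins for the `perf_counter()`-derived values):

```python
results = [
    {"compile_time": 120.0, "verification_time": 5.0, "benchmark_time": 300.0},
    None,  # e.g. budget exhausted: no result for this slot
    {"compile_time": 90.0, "verification_time": 4.0, "benchmark_time": 250.0},
]

# Time the workers actually spent compiling, verifying, and benchmarking
total_worker_time = sum(
    r["compile_time"] + r["verification_time"] + r["benchmark_time"]
    for r in results if r
)

total_time = 900.0    # hypothetical wall-clock time (ms) since start_time
strategy_time = 50.0  # hypothetical accumulated strategy time (ms)

num_valid_results = sum(bool(r) for r in results)
if num_valid_results > 0:
    # Unexplained wall-clock time is attributed to the framework, clamped at 0
    framework_time = max(total_time - strategy_time - total_worker_time, 0)
    for result in results:
        if result:
            result["strategy_time"] = strategy_time / num_valid_results
            result["framework_time"] = framework_time / num_valid_results

print(results[0]["framework_time"])  # (900 - 50 - 769) / 2 = 40.5
```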

kernel_tuner/runners/simulation.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -57,6 +57,7 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob
         self.total_simulated_time = 0
         self.last_strategy_start_time = self.start_time
         self.last_strategy_time = 0
+        self.visited_results = set()
         self.units = {}
 
     def get_device_info(self):
@@ -106,10 +107,18 @@ def run(self, parameter_space, tuning_options):
             if tuning_options.metrics and not isinstance(result.get(tuning_options.objective), util.ErrorConfig):
                 result = util.process_metrics(result, tuning_options.metrics)
 
-            # configuration is evaluated for the first time, print to the console
-            util.print_config_output(
-                tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units
-            )
+            # Simulate the behavior of the sequential runner: when a configuration
+            # is served from the cache, the compile_time, verification_time, and
+            # benchmark_time are set to 0. In the simulation runner this step is
+            # only performed when a configuration is served from the cache beyond
+            # the first time, that is, when the configuration is already counted
+            # towards the unique_results.
+            if key in self.visited_results:
+                result = util.disable_benchmark_timings(result)
+            else:
+                # configuration is evaluated for the first time, print to the console
+                util.print_config_output(tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units)
+                self.visited_results.add(key)
 
             # Everything but the strategy time and framework time are simulated,
             result["strategy_time"] = strategy_time_per_config
```

kernel_tuner/strategies/common.py

Lines changed: 11 additions & 36 deletions
```diff
@@ -129,7 +129,7 @@ def _normalize_and_validate_config(self, x, check_restrictions=True):
 
     def _run_configs(self, xs, check_restrictions=True):
         """ Takes a list of Euclidean coordinates and evaluates the configurations at those points. """
-        self.runner.last_strategy_time = 1000 * (perf_counter() - self.runner.last_strategy_start_time)
+        self.runner.last_strategy_time += 1000 * (perf_counter() - self.runner.last_strategy_start_time)
         self.runner.start_time = perf_counter()  # start framework time
 
         # error value to return for numeric optimizers that need a numerical value
@@ -138,68 +138,43 @@ def _run_configs(self, xs, check_restrictions=True):
         # check if max_fevals is reached or time limit is exceeded
         self.tuning_options.budget.raise_exception_if_done()
 
+        batch_indices = []  # Where to store the result in `final_results`
         batch_configs = []  # The configs to run
-        batch_keys = []  # The keys of the configs to run
-        pending_indices_by_key = dict()  # Maps key => where to store result in `final_results`
         final_results = []  # List returned to the user
-        legal_indices = []  # Indices in `final_results` that are legal
 
-        # Loop over all configurations. For each configuration there are four cases:
-        # 1. The configuration is invalid, we can skip it
-        # 2. The configuration is in `unique_results`, we can get it from there
-        # 3. The configuration is in `pending_indices_by_key`, it is a duplicate in `xs`
-        # 4. The configuration must be evaluated by the runner.
+        # Loop over all configurations.
         for index, x in enumerate(xs):
             config, is_legal = self._normalize_and_validate_config(x, check_restrictions=check_restrictions)
             logging.debug("normalize config: %s -> %s (legal: %s)", str(x), str(config), is_legal)
             key = ",".join([str(i) for i in config])
 
-            # 1. Not legal, just return `InvalidConfig`
+            # Not legal, just return `InvalidConfig`
             if not is_legal:
                 result = dict(zip(self.searchspace.tune_params.keys(), config))
                 result[self.objective] = util.InvalidConfig()
                 final_results.append(result)
 
-            # 2. Attempt to retrieve from `unique_results`
-            elif key in self.unique_results:
-                result = dict(self.unique_results[key])
-                legal_indices.append(index)
-                final_results.append(result)
-
-            # 3. We have already seen this config in the current batch
-            elif key in pending_indices_by_key:
-                pending_indices_by_key[key].append(index)
-                final_results.append(None)
-
-            # 4. A new config, we must evaluate this
+            # Legal config, we must evaluate this
             else:
-                batch_keys.append(key)
+                batch_indices.append(index)
                 batch_configs.append(config)
-                pending_indices_by_key[key] = [index]
                 final_results.append(None)
 
         # compile and benchmark the batch
         batch_results = self.runner.run(batch_configs, self.tuning_options)
 
-        for key, result in zip(batch_keys, batch_results):
+        for index, config, result in zip(batch_indices, batch_configs, batch_results):
             # Skip. Result is missing because the runner has exhausted the budget
             if result is None:
                 continue
 
             # set in the results array
-            for index in pending_indices_by_key[key]:
-                legal_indices.append(index)
-                final_results[index] = result
-
-            # Disable the timings. Only the first result must get these.
-            result = util.disable_benchmark_timings(result)
+            final_results[index] = result
 
             # Put result in `unique_results`
-            self.unique_results[key] = result
-
-        # Only things in `legal_indices` are valid results
-        for index in sorted(legal_indices):
-            self.results.append(final_results[index])
+            key = ",".join([str(i) for i in config])
+            self.unique_results.setdefault(key, result)
+            self.results.append(result)
 
         # upon returning from this function control will be given back to the strategy, so reset the start time
         self.runner.last_strategy_start_time = perf_counter()
```
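
One detail worth noting is the switch to `dict.setdefault` for `unique_results`: since the runners now hand back duplicate results with their timings zeroed, `setdefault` appears intended to keep the first (fully timed) result for a key and never overwrite it with a later duplicate. A minimal illustration with made-up values:

```python
unique_results = {}

first = {"time": 1.23, "benchmark_time": 10.0}  # first evaluation, timings kept
dup = {"time": 1.23, "benchmark_time": 0.0}     # duplicate, timings zeroed

unique_results.setdefault("16,1", first)
unique_results.setdefault("16,1", dup)  # no-op: key already present

assert unique_results["16,1"]["benchmark_time"] == 10.0
```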
