Commit dbcb89d

Completed merge with updated custom_strategies

2 parents: 72b615b + a4a69ae

6 files changed: +85 -57 lines
kernel_tuner/backends/hip.py (18 additions, 13 deletions)

@@ -19,7 +19,6 @@
     "bool": ctypes.c_bool,
     "int8": ctypes.c_int8,
     "int16": ctypes.c_int16,
-    "float16": ctypes.c_int16,
     "int32": ctypes.c_int32,
     "int64": ctypes.c_int64,
     "uint8": ctypes.c_uint8,
@@ -40,7 +39,9 @@ def hip_check(call_result):
     if len(result) == 1:
         result = result[0]
     if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
-        raise RuntimeError(str(err), hip.hipGetLastError())
+        _, error_name = hip.hipGetErrorName(err)
+        _, error_str = hip.hipGetErrorString(err)
+        raise RuntimeError(f"{error_name}: {error_str}")
     return result


@@ -120,25 +121,29 @@ def ready_argument_list(self, arguments):

             # Handle numpy arrays
             if isinstance(arg, np.ndarray):
-                if dtype_str in dtype_map.keys():
-                    # Allocate device memory
-                    device_ptr = hip_check(hip.hipMalloc(arg.nbytes))
+                # Allocate device memory
+                device_ptr = hip_check(hip.hipMalloc(arg.nbytes))

-                    # Copy data to device using hipMemcpy
-                    hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
+                # Copy data to device using hipMemcpy
+                hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

-                    prepared_args.append(device_ptr)
-                else:
-                    raise TypeError(f"Unknown dtype {dtype_str} for ndarray")
+                prepared_args.append(device_ptr)

             # Handle numpy scalar types
             elif isinstance(arg, np.generic):
                 # Convert numpy scalar to corresponding ctypes
-                ctype_arg = dtype_map[dtype_str](arg)
-                prepared_args.append(ctype_arg)
+                if dtype_str in dtype_map:
+                    ctype_arg = dtype_map[dtype_str](arg)
+                    prepared_args.append(ctype_arg)
+                # 16-bit float is not supported, view it as uint16
+                elif dtype_str in ("float16", "bfloat16"):
+                    ctype_arg = ctypes.c_uint16(arg.view(np.uint16))
+                    prepared_args.append(ctype_arg)
+                else:
+                    raise ValueError(f"Invalid argument type {dtype_str}: {arg}")

             else:
-                raise ValueError(f"Invalid argument type {type(arg)}, {arg}")
+                raise ValueError(f"Invalid argument type {type(arg)}: {arg}")

         return prepared_args
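Note on the float16 handling: ctypes has no 16-bit float type, so the new scalar path passes the raw bits of a half-precision value through as an unsigned 16-bit integer. A minimal standalone sketch of the same bit-view trick (plain numpy/ctypes, independent of the HIP backend):

import ctypes
import numpy as np

# ctypes lacks c_float16, so reinterpret the 16 bits as an unsigned integer.
# The kernel still receives the original bit pattern and can treat it as half.
half = np.float16(1.5)
bits = half.view(np.uint16)        # same 16 bits, integer dtype
ctype_arg = ctypes.c_uint16(bits)

print(hex(ctype_arg.value))        # 0x3e00, the IEEE-754 binary16 encoding of 1.5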

kernel_tuner/core.py (2 additions, 2 deletions)

@@ -509,7 +509,7 @@ def check_kernel_output(
         # run the kernel
         check = self.run_kernel(func, gpu_args, instance)
         if not check:
-            # runtime failure occured that should be ignored, skip correctness check
+            # runtime failure occurred that should be ignored, skip correctness check
             return

         # retrieve gpu results to host memory
@@ -908,7 +908,7 @@ def split_argument_list(argument_list):
         match = re.match(regex, arg, re.S)
         if not match:
             raise ValueError("error parsing templated kernel argument list")
-        type_list.append(re.sub(r"\s+", " ", match.group(1).strip(), re.S))
+        type_list.append(re.sub(r"\s+", " ", match.group(1).strip(), flags=re.S))
         name_list.append(match.group(2).strip())
     return type_list, name_list
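Note on the re.sub change: the fourth positional parameter of re.sub is count, not flags, so the old call passed re.S (an IntFlag with value 16) as a substitution limit instead of a matching flag. A small demonstration of the pitfall:

import re

text = "  ".join("abcdefghijklmnopqr")  # 18 letters -> 17 double-space gaps

buggy = re.sub(r"\s+", " ", text, re.S)        # re.S (== 16) lands in the count slot
fixed = re.sub(r"\s+", " ", text, flags=re.S)  # passed as an actual flag

print(buggy == fixed)  # False: the buggy call stops after 16 substitutions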

kernel_tuner/strategies/common.py (2 additions, 1 deletion)

@@ -100,6 +100,7 @@ def __init__(
             tuning_options["max_fevals"] if "max_fevals" in tuning_options else np.inf, searchspace.size
         )
         self.results = []
+        self.budget_spent_fraction = 0.0

         # if enabled, encode non-numeric parameter values as a numeric value
         if self.encode_non_numeric:
@@ -127,7 +128,7 @@ def __call__(self, x, check_restrictions=True):
         logging.debug("x: %s", str(x))

         # check if max_fevals is reached or time limit is exceeded
-        util.check_stop_criterion(self.tuning_options)
+        self.budget_spent_fraction = util.check_stop_criterion(self.tuning_options)

         # snap values in x to nearest actual value for each parameter, unscale x if needed
         if self.snap:
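With this change, every call through CostFunc first refreshes budget_spent_fraction, so a custom optimizer can poll the attribute between evaluations instead of keeping its own counter. A hedged sketch of that contract using a hypothetical stand-in class (FakeCostFunc is not part of Kernel Tuner; it only mirrors the bookkeeping):

class StopCriterionReached(Exception):
    """Mimics kernel_tuner.util.StopCriterionReached."""

class FakeCostFunc:
    """Hypothetical stand-in mirroring CostFunc's budget bookkeeping."""

    def __init__(self, max_fevals):
        self.max_fevals = max_fevals
        self.fevals = 0
        self.budget_spent_fraction = 0.0

    def __call__(self, x):
        # As in CostFunc.__call__: refresh the fraction, raise when exhausted.
        self.budget_spent_fraction = self.fevals / self.max_fevals
        if self.fevals >= self.max_fevals:
            raise StopCriterionReached("max_fevals reached")
        self.fevals += 1
        return sum(v * v for v in x)  # dummy objective

func = FakeCostFunc(max_fevals=5)
while func.budget_spent_fraction < 1.0:
    try:
        func((1.0, 2.0))
    except StopCriterionReached:
        break
print(func.fevals, func.budget_spent_fraction)  # 5 evaluations, fraction 1.0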

kernel_tuner/strategies/wrapper.py (24 additions, 3 deletions)

@@ -1,16 +1,37 @@
 """Wrapper intended for user-defined custom optimization methods"""

+from abc import ABC, abstractmethod
+
 from kernel_tuner import util
 from kernel_tuner.searchspace import Searchspace
 from kernel_tuner.strategies.common import CostFunc


+class OptAlg(ABC):
+    """Base class for user-defined optimization algorithms."""
+
+    def __init__(self):
+        self.costfunc_kwargs = {"scaling": True, "snap": True}
+
+    @abstractmethod
+    def __call__(self, func: CostFunc, searchspace: Searchspace) -> tuple[tuple, float]:
+        """Optimize the black box function `func` within the given `searchspace`.
+
+        Args:
+            func (CostFunc): Cost function to be optimized. Has a property `budget_spent_fraction` that indicates how much of the budget has been spent.
+            searchspace (Searchspace): Search space containing the parameters to be optimized.
+
+        Returns:
+            tuple[tuple, float]: tuple of the best parameters and the corresponding cost value.
+        """
+        pass
+
+
 class OptAlgWrapper:
     """Wrapper class for user-defined optimization algorithms"""

-    def __init__(self, optimizer):
-        self.optimizer = optimizer
-
+    def __init__(self, optimizer: OptAlg):
+        self.optimizer: OptAlg = optimizer

     def tune(self, searchspace: Searchspace, runner, tuning_options):
         cost_func = CostFunc(searchspace, tuning_options, runner, **self.optimizer.costfunc_kwargs)
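To illustrate the new interface, here is a minimal concrete OptAlg subclass, a plain random search. This is a sketch, not part of the commit, and it assumes that with scaling enabled (set in OptAlg.__init__), points drawn in [0, 1) per dimension are mapped onto valid configurations by CostFunc's snap=True:

import numpy as np

from kernel_tuner.strategies.wrapper import OptAlg
from kernel_tuner.util import StopCriterionReached


class RandomSearch(OptAlg):
    """Toy optimizer: sample uniformly until the budget runs out."""

    def __call__(self, func, searchspace):
        best_x, best_val = None, np.inf
        try:
            while func.budget_spent_fraction < 1.0:
                # Draw a point in the scaled space; snap=True maps it to the
                # nearest valid configuration before evaluation.
                x = np.random.uniform(0.0, 1.0, searchspace.num_params)
                val = func(x)
                if val < best_val:
                    best_x, best_val = x.copy(), val
        except StopCriterionReached:
            pass  # budget exhausted mid-evaluation
        return best_x, best_val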

kernel_tuner/util.py (22 additions, 6 deletions)

@@ -190,12 +190,28 @@ def check_argument_list(kernel_name, kernel_string, args):
         warnings.warn(errors[0], UserWarning)


-def check_stop_criterion(to):
-    """Checks if max_fevals is reached or time limit is exceeded."""
-    if "max_fevals" in to and len(to.unique_results) >= to.max_fevals:
-        raise StopCriterionReached(f"max_fevals reached ({len(to.unique_results)} >= {to.max_fevals})")
-    if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3) + to.startup_time) > to.time_limit):
-        raise StopCriterionReached("time limit exceeded")
+def check_stop_criterion(to: dict) -> float:
+    """Check if the stop criterion is reached.
+
+    Args:
+        to (dict): tuning options.
+
+    Raises:
+        StopCriterionReached: if the max_fevals is reached or time limit is exceeded.
+
+    Returns:
+        float: fraction of budget spent.
+    """
+    if "max_fevals" in to:
+        if len(to.unique_results) >= to.max_fevals:
+            raise StopCriterionReached(f"max_fevals ({to.max_fevals}) reached")
+        return len(to.unique_results) / to.max_fevals
+    if "time_limit" in to:
+        time_spent = (time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3) + to.startup_time
+        if time_spent > to.time_limit:
+            raise StopCriterionReached("time limit exceeded")
+        return time_spent / to.time_limit


 def check_tune_params_list(tune_params, observers, simulation_mode=False):
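The new contract: raise StopCriterionReached once the budget is exhausted, otherwise return the fraction spent, with max_fevals taking precedence over time_limit when both are set. A worked example of the max_fevals branch (hedged: the real `to` is Kernel Tuner's attribute-accessible options dict; a SimpleNamespace with hasattr stands in for the `in` check, and simulated_time/startup_time are omitted for brevity):

import time
from types import SimpleNamespace

class StopCriterionReached(Exception):
    pass

def check_stop_criterion(to):
    # Simplified mirror of the updated util.check_stop_criterion.
    if hasattr(to, "max_fevals"):
        if len(to.unique_results) >= to.max_fevals:
            raise StopCriterionReached(f"max_fevals ({to.max_fevals}) reached")
        return len(to.unique_results) / to.max_fevals
    if hasattr(to, "time_limit"):
        time_spent = time.perf_counter() - to.start_time
        if time_spent > to.time_limit:
            raise StopCriterionReached("time limit exceeded")
        return time_spent / to.time_limit

to = SimpleNamespace(max_fevals=20, unique_results={"cfg1": 1.0, "cfg2": 2.0})
print(check_stop_criterion(to))  # 0.1: 2 of 20 unique evaluations used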

test/test_custom_optimizer.py (17 additions, 32 deletions)

@@ -3,7 +3,9 @@

 import numpy as np

-class HybridDELocalRefinement:
+from kernel_tuner.strategies.wrapper import OptAlg
+
+class HybridDELocalRefinement(OptAlg):
     """
     A two-phase differential evolution with local refinement, intended for BBOB-type
     black box optimization problems in [-5,5]^dim.
@@ -12,21 +14,14 @@ class HybridDELocalRefinement:
     exploration and local exploitation under a strict function evaluation budget.
     """

-    def __init__(self, budget, dim):
-        """
-        Initialize the optimizer with:
-        - budget: total number of function evaluations allowed.
-        - dim: dimensionality of the search space.
-        """
-        self.budget = budget
-        self.dim = dim
+    def __init__(self):
+        super().__init__()
         # You can adjust these hyperparameters based on experimentation/tuning:
-        self.population_size = min(50, 10 * dim)  # Caps for extremely large dim
         self.F = 0.8  # Differential weight
         self.CR = 0.9  # Crossover probability
         self.local_search_freq = 10  # Local refinement frequency in generations

-    def __call__(self, func):
+    def __call__(self, func, searchspace):
         """
         Optimize the black box function `func` in [-5,5]^dim, using
         at most self.budget function evaluations.
@@ -35,9 +30,8 @@ def __call__(self, func):
             best_params: np.ndarray representing the best parameters found
             best_value: float representing the best objective value found
         """
-        # Check if we have a non-positive budget
-        if self.budget <= 0:
-            raise ValueError("Budget must be a positive integer.")
+        self.dim = searchspace.num_params
+        self.population_size = round(min(min(50, 10 * self.dim), np.ceil(searchspace.size / 3)))  # Caps for extremely large dim

         # 1. Initialize population
         lower_bound, upper_bound = -5.0, 5.0
@@ -49,8 +43,6 @@
         for i in range(self.population_size):
             fitness[i] = func(pop[i])
             evaluations += 1
-            if evaluations >= self.budget:
-                break

         # Track best solution
         best_idx = np.argmin(fitness)
@@ -59,7 +51,7 @@

         # 2. Main evolutionary loop
         gen = 0
-        while evaluations < self.budget:
+        while func.budget_spent_fraction < 1.0 and evaluations < searchspace.size:
             gen += 1
             for i in range(self.population_size):
                 # DE mutation: pick three distinct indices
@@ -78,7 +70,7 @@
                 # Evaluate trial
                 trial_fitness = func(trial)
                 evaluations += 1
-                if evaluations >= self.budget:
+                if func.budget_spent_fraction > 1.0:
                     # If out of budget, wrap up
                     if trial_fitness < fitness[i]:
                         pop[i] = trial
@@ -99,14 +91,11 @@
                         best_params = trial.copy()

             # Periodically refine best solution with a small local neighborhood search
-            if gen % self.local_search_freq == 0 and evaluations < self.budget:
+            if gen % self.local_search_freq == 0 and func.budget_spent_fraction < 1.0:
                 best_params, best_value, evaluations = self._local_refinement(
                     func, best_params, best_value, evaluations, lower_bound, upper_bound
                 )

-            if evaluations >= self.budget:
-                break
-
         return best_params, best_value

     def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub):
@@ -115,11 +104,10 @@ def _local_refinement(self, func, best_params, best_value, evaluations, lb, ub):
         Uses a quick 'perturb-and-accept' approach in a shrinking neighborhood.
         """
         # Neighborhood size shrinks as the budget is consumed
-        frac_budget_used = evaluations / self.budget
-        step_size = 0.2 * (1.0 - frac_budget_used)
+        step_size = 0.2 * (1.0 - func.budget_spent_fraction)

         for _ in range(5):  # 5 refinements each time
-            if evaluations >= self.budget:
+            if func.budget_spent_fraction >= 1.0:
                 break
             candidate = best_params + np.random.uniform(-step_size, step_size, self.dim)
             candidate = np.clip(candidate, lb, ub)
@@ -138,26 +126,23 @@
 import os
 from kernel_tuner import tune_kernel
 from kernel_tuner.strategies.wrapper import OptAlgWrapper
-cache_filename = os.path.dirname(
-
-    os.path.realpath(__file__)) + "/test_cache_file.json"

 from .test_runners import env

+cache_filename = os.path.dirname(os.path.realpath(__file__)) + "/test_cache_file.json"

 def test_OptAlgWrapper(env):
     kernel_name, kernel_string, size, args, tune_params = env

     # Instantiate LLaMAE optimization algorithm
-    budget = int(15)
-    dim = len(tune_params)
-    optimizer = HybridDELocalRefinement(budget, dim)
+    optimizer = HybridDELocalRefinement()

     # Wrap the algorithm class in the OptAlgWrapper
     # for use in Kernel Tuner
     strategy = OptAlgWrapper(optimizer)
+    strategy_options = { 'max_fevals': 15 }

     # Call the tuner
     tune_kernel(kernel_name, kernel_string, size, args, tune_params,
-                strategy=strategy, cache=cache_filename,
+                strategy=strategy, strategy_options=strategy_options, cache=cache_filename,
                 simulation_mode=True, verbose=True)
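A note on the design choice visible in this test: the evaluation budget now comes in through strategy_options ({'max_fevals': 15}) rather than the optimizer's constructor, and the optimizer only reads func.budget_spent_fraction. The same optimizer instance therefore works unchanged whether the budget is expressed as max_fevals or as time_limit.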
