|
1 | 1 | """Wrapper intended for user-defined custom optimization methods"""
|
2 | 2 |
|
| 3 | +from abc import ABC, abstractmethod |
| 4 | + |
3 | 5 | from kernel_tuner import util
|
4 | 6 | from kernel_tuner.searchspace import Searchspace
|
5 | 7 | from kernel_tuner.strategies.common import CostFunc
|
6 | 8 |
|
7 | 9 |
|
| 10 | +class OptAlg(ABC): |
| 11 | + """Base class for user-defined optimization algorithms.""" |
| 12 | + |
| 13 | + def __init__(self): |
| 14 | + self.costfunc_kwargs = {"scaling": True, "snap": True} |
| 15 | + |
| 16 | + @abstractmethod |
| 17 | + def __call__(self, func: CostFunc, searchspace: Searchspace, budget_spent_fraction: float) -> tuple[tuple, float]: |
| 18 | + """_summary_ |
| 19 | +
|
| 20 | + Args: |
| 21 | + func (CostFunc): Cost function to be optimized. |
| 22 | + searchspace (Searchspace): Search space containing the parameters to be optimized. |
| 23 | + budget_spent_fraction (float): Fraction of the budget that has already been spent. |
| 24 | +
|
| 25 | + Returns: |
| 26 | + tuple[tuple, float]: tuple of the best parameters and the corresponding cost value |
| 27 | + """ |
| 28 | + pass |
| 29 | + |
| 30 | + |
8 | 31 | class OptAlgWrapper:
|
9 | 32 | """Wrapper class for user-defined optimization algorithms"""
|
10 | 33 |
|
11 |
| - def __init__(self, optimizer, scaling=True): |
12 |
| - self.optimizer = optimizer |
13 |
| - self.scaling = scaling |
14 |
| - |
| 34 | + def __init__(self, optimizer: OptAlg): |
| 35 | + self.optimizer: OptAlg = optimizer |
15 | 36 |
|
16 | 37 | def tune(self, searchspace: Searchspace, runner, tuning_options):
|
17 |
| - cost_func = CostFunc(searchspace, tuning_options, runner, scaling=self.scaling) |
| 38 | + cost_func = CostFunc(searchspace, tuning_options, runner, **self.optimizer.costfunc_kwargs) |
18 | 39 |
|
19 |
| - if self.scaling: |
| 40 | + if self.optimizer.costfunc_kwargs.get('scaling', True): |
20 | 41 | # Initialize costfunc for scaling
|
21 | 42 | cost_func.get_bounds_x0_eps()
|
22 | 43 |
|
23 | 44 | try:
|
24 |
| - self.optimizer(cost_func) |
| 45 | + self.optimizer(cost_func, searchspace) |
25 | 46 | except util.StopCriterionReached as e:
|
26 | 47 | if tuning_options.verbose:
|
27 | 48 | print(e)
|
|
0 commit comments