Skip to content

Commit 5ce2495

Browse files
committed
Implemented abstract base class for custom optimization algorithm strategies
1 parent 83ab888 commit 5ce2495

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

kernel_tuner/strategies/wrapper.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,48 @@
11
"""Wrapper intended for user-defined custom optimization methods"""
22

3+
from abc import ABC, abstractmethod
4+
35
from kernel_tuner import util
46
from kernel_tuner.searchspace import Searchspace
57
from kernel_tuner.strategies.common import CostFunc
68

79

10+
class OptAlg(ABC):
11+
"""Base class for user-defined optimization algorithms."""
12+
13+
def __init__(self):
14+
self.costfunc_kwargs = {"scaling": True, "snap": True}
15+
16+
@abstractmethod
17+
def __call__(self, func: CostFunc, searchspace: Searchspace, budget_spent_fraction: float) -> tuple[tuple, float]:
18+
"""_summary_
19+
20+
Args:
21+
func (CostFunc): Cost function to be optimized.
22+
searchspace (Searchspace): Search space containing the parameters to be optimized.
23+
budget_spent_fraction (float): Fraction of the budget that has already been spent.
24+
25+
Returns:
26+
tuple[tuple, float]: tuple of the best parameters and the corresponding cost value
27+
"""
28+
pass
29+
30+
831
class OptAlgWrapper:
932
"""Wrapper class for user-defined optimization algorithms"""
1033

11-
def __init__(self, optimizer, scaling=True):
12-
self.optimizer = optimizer
13-
self.scaling = scaling
14-
34+
def __init__(self, optimizer: OptAlg):
35+
self.optimizer: OptAlg = optimizer
1536

1637
def tune(self, searchspace: Searchspace, runner, tuning_options):
17-
cost_func = CostFunc(searchspace, tuning_options, runner, scaling=self.scaling)
38+
cost_func = CostFunc(searchspace, tuning_options, runner, **self.optimizer.costfunc_kwargs)
1839

19-
if self.scaling:
40+
if self.optimizer.costfunc_kwargs.get('scaling', True):
2041
# Initialize costfunc for scaling
2142
cost_func.get_bounds_x0_eps()
2243

2344
try:
24-
self.optimizer(cost_func)
45+
self.optimizer(cost_func, searchspace)
2546
except util.StopCriterionReached as e:
2647
if tuning_options.verbose:
2748
print(e)

0 commit comments

Comments
 (0)