Skip to content

Commit 0c401f3

Browse files
committed
add parameter to search class
1 parent 190058b commit 0c401f3

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

src/gradient_free_optimizers/optimizer_search/powells_method.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010

1111
class PowellsMethod(_PowellsMethod, Search):
1212
"""
13-
A class implementing **pattern search** for the public API.
13+
A class implementing **Powell's conjugate direction method** for the public API.
1414
Inheriting from the `Search`-class to get the `search`-method and from
1515
the `PowellsMethod`-backend to get the underlying algorithm.
1616
17+
Powell's method performs sequential line searches along a set of directions,
18+
updating the directions after each complete cycle to form conjugate directions.
19+
This leads to faster convergence than simple coordinate descent.
20+
1721
Parameters
1822
----------
1923
search_space : dict[str, list]
@@ -32,12 +36,16 @@ class PowellsMethod(_PowellsMethod, Search):
3236
rand_rest_p : float
3337
The probability of a random iteration during the search process.
3438
epsilon : float
35-
The step-size for the climbing.
39+
The step-size for hill climbing line search.
3640
distribution : str
37-
The type of distribution to sample from.
41+
The type of distribution to sample from for hill climbing.
3842
n_neighbours : int
3943
The number of neighbours to sample and evaluate before moving to the best
4044
of those neighbours.
45+
iters_per_direction : int
46+
Number of evaluations per direction during line search.
47+
line_search : str
48+
Line search method: "grid" (default), "golden", or "hill_climb".
4149
"""
4250

4351
def __init__(
@@ -51,7 +59,11 @@ def __init__(
5159
random_state: int = None,
5260
rand_rest_p: float = 0,
5361
nth_process: int = None,
54-
iters_p_dim: int = 10,
62+
epsilon: float = 0.03,
63+
distribution: str = "normal",
64+
n_neighbours: int = 3,
65+
iters_per_direction: int = 10,
66+
line_search: Literal["grid", "golden", "hill_climb"] = "grid",
5567
):
5668
super().__init__(
5769
search_space=search_space,
@@ -60,5 +72,9 @@ def __init__(
6072
random_state=random_state,
6173
rand_rest_p=rand_rest_p,
6274
nth_process=nth_process,
63-
iters_p_dim=iters_p_dim,
75+
epsilon=epsilon,
76+
distribution=distribution,
77+
n_neighbours=n_neighbours,
78+
iters_per_direction=iters_per_direction,
79+
line_search=line_search,
6480
)

0 commit comments

Comments
 (0)