Skip to content

Commit b871650

Browse files
committed
continue refactoring
1 parent 75854fb commit b871650

File tree

5 files changed

+41
-40
lines changed

5 files changed

+41
-40
lines changed

gradient_free_optimizers/optimizers/core_optimizer/core_optimizer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def __init__(
2121
):
2222
super().__init__()
2323

24+
self.random_seed = set_random_seed(nth_process, random_state)
25+
2426
self.conv = Converter(search_space)
2527
self.init = Initializer(self.conv, initialize)
2628

@@ -30,8 +32,6 @@ def __init__(
3032
self.nth_process = nth_process
3133
self.debug_log = debug_log
3234

33-
self.random_seed = set_random_seed(nth_process, random_state)
34-
3535
def random_iteration(func):
3636
def wrapper(self, *args, **kwargs):
3737
if self.rand_rest_p > random.uniform(0, 1):
@@ -62,10 +62,7 @@ def move_random(self):
6262

6363
@SearchTracker.track_new_pos
6464
def init_pos(self):
65-
print("\n 11111 self.init.init_positions_l", self.init.init_positions_l)
66-
print(" self.nth_trial", self.nth_trial)
67-
68-
init_pos = self.init.init_positions_l[self.nth_trial]
65+
init_pos = self.init.init_positions_l[self.n_init_total]
6966
return init_pos
7067

7168
def finish_initialization(self):

gradient_free_optimizers/optimizers/pop_opt/base_population_optimizer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ def sort_pop_best_score(self):
3535
self.pop_sorted = [self.optimizers[i] for i in idx_sorted_ind]
3636

3737
def _create_population(self, Optimizer):
38-
diff_init = self.population - self.init.n_inits
38+
if isinstance(self.population, int):
39+
pop_size = self.population
40+
else:
41+
pop_size = len(self.population)
42+
diff_init = pop_size - self.init.n_inits
43+
3944
if diff_init > 0:
4045
self.init.add_n_random_init_pos(diff_init)
4146

@@ -55,8 +60,6 @@ def _create_population(self, Optimizer):
5560
else:
5661
population = self.population
5762

58-
print("\n init_positions_l \n", self.init.init_positions_l, "\n")
59-
6063
return population
6164

6265
@CoreOptimizer.track_new_score

gradient_free_optimizers/results_manager.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77

88
class ResultsManager:
9-
def __init__(self, conv):
9+
def __init__(self):
1010
super().__init__()
11-
self.conv = conv
11+
self.conv = None
1212

1313
self.results_list = []
1414

@@ -33,14 +33,11 @@ def _wrapper(pos):
3333
results_dict = self._obj_func_results(objective_function, para)
3434

3535
self.results_list.append({**results_dict, **para})
36-
print("score self.results_list", self.results_list)
3736

3837
return results_dict["score"]
3938

4039
return _wrapper
4140

4241
@property
4342
def search_data(self):
44-
print("self.results_list", self.results_list)
45-
4643
return pd.DataFrame(self.results_list)

gradient_free_optimizers/search.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88

99
from .progress_bar import ProgressBarLVL0, ProgressBarLVL1
1010
from .times_tracker import TimesTracker
11+
from .search_statistics import SearchStatistics
1112
from .memory import Memory
1213
from .print_info import print_info
1314
from .stop_run import StopRun
1415

1516
from .results_manager import ResultsManager
1617

1718

18-
class Search(TimesTracker):
19+
class Search(TimesTracker, SearchStatistics):
1920
def __init__(self):
2021
super().__init__()
2122

@@ -25,12 +26,11 @@ def __init__(self):
2526

2627
self.score_l = []
2728
self.pos_l = []
28-
self.nth_iter = 0
2929
self.random_seed = None
3030

3131
self.search_state = "init"
32-
self.n_init_total = 0
33-
self.n_iter_total = 0
32+
33+
self.results_mang = ResultsManager()
3434

3535
@TimesTracker.eval_time
3636
def _score(self, pos):
@@ -42,11 +42,9 @@ def _initialization(self, nth_iter):
4242
self.best_score = self.p_bar.score_best
4343

4444
init_pos = self.init_pos()
45-
print("\n init_pos ", init_pos)
4645

4746
score_new = self._score(init_pos)
4847
self.evaluate_init(score_new)
49-
print("\n score_new ", score_new)
5048

5149
self.pos_l.append(init_pos)
5250
self.score_l.append(score_new)
@@ -84,6 +82,7 @@ def _init_search(self):
8482
self.nth_process, self.n_iter, self.objective_function
8583
)
8684

85+
@SearchStatistics.init_stats
8786
def search(
8887
self,
8988
objective_function,
@@ -97,10 +96,7 @@ def search(
9796
):
9897
self.start_time = time.time()
9998

100-
self.results_mang = ResultsManager(self.conv)
101-
102-
self.n_init_search = 0
103-
self.n_iter_search = 0
99+
self.results_mang.conv = self.conv
104100

105101
if verbosity is False:
106102
verbosity = []
@@ -127,34 +123,19 @@ def search(
127123
else:
128124
self.score = self.results_mang.score(objective_function)
129125

130-
print(
131-
"\n search init_positions_l \n",
132-
self.init.init_positions_l,
133-
"\n",
134-
)
135-
136-
n_inits_norm = min(self.init.n_inits, n_iter)
137-
print("\n n_inits_norm", n_inits_norm)
126+
n_inits_norm = min((self.init.n_inits - self.n_init_total), n_iter)
138127

139128
# if self.search_state == "init":
140129
# loop to initialize N positions
141130
for nth_iter in range(n_inits_norm):
142-
print("\n init!")
143131
if self.stop.check(self.start_time, self.p_bar.score_best, self.score_l):
144132
break
145133
self._initialization(nth_iter)
146134

147-
print("pos_new_list", self.pos_new_list)
148-
149135
self.finish_initialization()
150136

151-
print("\n self.n_init_search", self.n_init_search)
152-
print("\n self.n_init_total", self.n_init_total)
153-
154137
# loop to do the iterations
155138
for nth_iter in range(self.n_init_search, n_iter):
156-
print("\n iter!")
157-
158139
if self.stop.check(self.start_time, self.p_bar.score_best, self.score_l):
159140
break
160141
self._iteration(nth_iter)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Author: Simon Blanke
2+
# Email: simon.blanke@yahoo.com
3+
# License: MIT License
4+
5+
6+
class SearchStatistics:
7+
def __init__(self):
8+
super().__init__()
9+
10+
self.nth_iter = 0
11+
12+
self.n_init_total = 0
13+
self.n_iter_total = 0
14+
15+
def init_stats(func):
16+
def wrapper(self, *args, **kwargs):
17+
self.n_init_search = 0
18+
self.n_iter_search = 0
19+
20+
res = func(self, *args, **kwargs)
21+
return res
22+
23+
return wrapper

0 commit comments

Comments
 (0)