Skip to content

Commit 3542fc7

Browse files
committed
FIX potential problem related to #69
1 parent 5a7207f commit 3542fc7

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

autosklearn/ensemble_builder.py

Lines changed: 16 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,9 @@ def main(self):
124124
# over time!
125125
old_dir_ensemble_list_mtimes = dir_ensemble_list_mtimes
126126
dir_ensemble_list_mtimes = []
127+
# The ensemble dir can contain non-model files. We filter them and
128+
# use the following list instead
129+
dir_ensemble_model_files = []
127130

128131
for dir_ensemble_file in dir_ensemble_list:
129132
if dir_ensemble_file.endswith("/"):
@@ -132,18 +135,19 @@ def main(self):
132135
self.logger.warning('Error loading file (not .npy): %s', dir_ensemble_file)
133136
continue
134137

138+
dir_ensemble_model_files.append(dir_ensemble_file)
135139
basename = os.path.basename(dir_ensemble_file)
136140
dir_ensemble_file = os.path.join(dir_ensemble, basename)
137141
mtime = os.path.getmtime(dir_ensemble_file)
138142
dir_ensemble_list_mtimes.append(mtime)
139143

140-
if len(dir_ensemble_list) == 0:
144+
if len(dir_ensemble_model_files) == 0:
141145
self.logger.debug('Directories are empty')
142146
time.sleep(2)
143147
used_time = watch.wall_elapsed('ensemble_builder')
144148
continue
145149

146-
if len(dir_ensemble_list) <= current_num_models and \
150+
if len(dir_ensemble_model_files) <= current_num_models and \
147151
old_dir_ensemble_list_mtimes == dir_ensemble_list_mtimes:
148152
self.logger.debug('Nothing has changed since the last time')
149153
time.sleep(2)
@@ -169,7 +173,7 @@ def main(self):
169173
model_names_to_scores = dict()
170174

171175
model_idx = 0
172-
for model_name in dir_ensemble_list:
176+
for model_name in dir_ensemble_model_files:
173177
if model_name.endswith("/"):
174178
model_name = model_name[:-1]
175179
basename = os.path.basename(model_name)
@@ -254,7 +258,7 @@ def main(self):
254258

255259
indices_to_model_names = dict()
256260
indices_to_run_num = dict()
257-
for i, model_name in enumerate(dir_ensemble_list):
261+
for i, model_name in enumerate(dir_ensemble_model_files):
258262
match = model_and_automl_re.search(model_name)
259263
automl_seed = int(match.group(1))
260264
num_run = int(match.group(2))
@@ -265,7 +269,8 @@ def main(self):
265269

266270
try:
267271
all_predictions_train, all_predictions_valid, all_predictions_test =\
268-
self.get_all_predictions(dir_ensemble, dir_ensemble_list,
272+
self.get_all_predictions(dir_ensemble,
273+
dir_ensemble_model_files,
269274
dir_valid, dir_valid_list,
270275
dir_test, dir_test_list,
271276
include_num_runs,
@@ -314,7 +319,7 @@ def main(self):
314319

315320
# Set this variable here to avoid re-running the ensemble builder
316321
# every two seconds in case the ensemble did not change
317-
current_num_models = len(dir_ensemble_list)
322+
current_num_models = len(dir_ensemble_model_files)
318323

319324
ensemble_predictions = ensemble.predict(all_predictions_train)
320325
if sys.version_info[0] == 2:
@@ -342,7 +347,7 @@ def main(self):
342347
backend.save_ensemble(ensemble, index_run, self.seed)
343348

344349
# Save predictions for valid and test data set
345-
if len(dir_valid_list) == len(dir_ensemble_list):
350+
if len(dir_valid_list) == len(dir_ensemble_model_files):
346351
all_predictions_valid = np.array(all_predictions_valid)
347352
ensemble_predictions_valid = ensemble.predict(all_predictions_valid)
348353
if self.task_type == BINARY_CLASSIFICATION:
@@ -379,11 +384,11 @@ def main(self):
379384
else:
380385
self.logger.info('Could not find as many validation set predictions (%d)'
381386
'as ensemble predictions (%d)!.',
382-
len(dir_valid_list), len(dir_ensemble_list))
387+
len(dir_valid_list), len(dir_ensemble_model_files))
383388

384389
del all_predictions_valid
385390

386-
if len(dir_test_list) == len(dir_ensemble_list):
391+
if len(dir_test_list) == len(dir_ensemble_model_files):
387392
all_predictions_test = np.array(all_predictions_test)
388393
ensemble_predictions_test = ensemble.predict(all_predictions_test)
389394
if self.task_type == BINARY_CLASSIFICATION:
@@ -420,11 +425,11 @@ def main(self):
420425
else:
421426
self.logger.info('Could not find as many test set predictions (%d) as '
422427
'ensemble predictions (%d)!',
423-
len(dir_test_list), len(dir_ensemble_list))
428+
len(dir_test_list), len(dir_ensemble_model_files))
424429

425430
del all_predictions_test
426431

427-
current_num_models = len(dir_ensemble_list)
432+
current_num_models = len(dir_ensemble_model_files)
428433
watch.stop_task('index_run' + str(index_run))
429434
time_iter = watch.get_wall_dur('index_run' + str(index_run))
430435
used_time = watch.wall_elapsed('ensemble_builder')

0 commit comments

Comments
 (0)