Skip to content

Commit 91017ed

Browse files
committed
MAINT fix #238, hide refit/predict warnings
1 parent 1c87889 commit 91017ed

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

autosklearn/automl.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import json
77
import os
88
import unittest.mock
9+
import warnings
910

1011

1112
from ConfigSpace.io import pcs
@@ -432,6 +433,12 @@ def _fit(self, datamanager):
432433
return self
433434

434435
def refit(self, X, y):
436+
def send_warnings_to_log(message, category, filename, lineno,
437+
file=None):
438+
self._logger.debug('%s:%s: %s:%s' %
439+
(filename, lineno, category.__name__, message))
440+
return
441+
435442
if self._keep_models is not True:
436443
raise ValueError(
437444
"Predict can only be called if 'keep_models==True'")
@@ -451,7 +458,9 @@ def refit(self, X, y):
451458
# the ordering of the data.
452459
for i in range(10):
453460
try:
454-
model.fit(X.copy(), y.copy())
461+
with warnings.catch_warnings():
462+
warnings.showwarning = send_warnings_to_log
463+
model.fit(X.copy(), y.copy())
455464
break
456465
except ValueError:
457466
indices = list(range(X.shape[0]))
@@ -477,15 +486,23 @@ def predict(self, X):
477486
self.ensemble_ is None:
478487
self._load_models()
479488

489+
def send_warnings_to_log(message, category, filename, lineno,
490+
file=None):
491+
self._logger.debug('%s:%s: %s:%s' %
492+
(filename, lineno, category.__name__, message))
493+
return
494+
480495
all_predictions = []
481496
for identifier in self.ensemble_.get_model_identifiers():
482497
model = self.models_[identifier]
483498

484499
X_ = X.copy()
485-
if self._task in REGRESSION_TASKS:
486-
prediction = model.predict(X_)
487-
else:
488-
prediction = model.predict_proba(X_)
500+
with warnings.catch_warnings():
501+
warnings.showwarning = send_warnings_to_log
502+
if self._task in REGRESSION_TASKS:
503+
prediction = model.predict(X_)
504+
else:
505+
prediction = model.predict_proba(X_)
489506

490507
if len(prediction.shape) < 1 or len(X_.shape) < 1 or \
491508
X_.shape[0] < 1 or prediction.shape[0] != X_.shape[0]:

0 commit comments

Comments
 (0)