-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[MNT] Fix test warnings in scripting part #3503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e9be151
94a9572
70d11c9
e7e1806
b6680f9
a4cca7d
7a0f2fe
bda648d
8e9a536
8b7a395
a1542f2
0ef1210
5733d8d
bc99d0c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,6 @@ | ||
| import warnings | ||
| from unittest.mock import patch | ||
|
|
||
| import numpy as np | ||
| from scipy import stats | ||
| import sklearn.metrics as skl_metrics | ||
|
|
@@ -56,7 +59,8 @@ def compute_distances(self, x1, x2=None): | |
| distances *= -2 | ||
| distances += xx | ||
| distances += yy | ||
| np.maximum(distances, 0, out=distances) | ||
| with np.errstate(invalid="ignore"): # Nans are fixed below | ||
| np.maximum(distances, 0, out=distances) | ||
| if x2 is None: | ||
| distances.flat[::distances.shape[0] + 1] = 0.0 | ||
| fixer = _distance.fix_euclidean_rows_normalized if self.normalize \ | ||
|
|
@@ -111,7 +115,8 @@ def compute_distances(self, x1, x2=None): | |
| distances *= -2 | ||
| distances += xx | ||
| distances += xx.T | ||
| np.maximum(distances, 0, out=distances) | ||
| with np.errstate(invalid="ignore"): # Nans are fixed below | ||
| np.maximum(distances, 0, out=distances) | ||
| distances.flat[::distances.shape[0] + 1] = 0.0 | ||
|
|
||
| fixer = _distance.fix_euclidean_cols_normalized if self.normalize \ | ||
|
|
@@ -153,11 +158,24 @@ def fit_cols(self, attributes, x, n_vals): | |
| Return `EuclideanColumnsModel` with stored means and variances | ||
| for normalization and imputation. | ||
| """ | ||
| def nowarn(msg, cat, *args, **kwargs): | ||
| if cat is RuntimeWarning and ( | ||
| msg == "Mean of empty slice" | ||
| or msg == "Degrees of freedom <= 0 for slice"): | ||
| if self.normalize: | ||
| raise ValueError("some columns have no defined values") | ||
| else: | ||
| orig_warn(msg, cat, *args, **kwargs) | ||
|
|
||
| self.check_no_discrete(n_vals) | ||
| means = np.nanmean(x, axis=0) | ||
| vars = np.nanvar(x, axis=0) | ||
| if self.normalize and (np.isnan(vars).any() or not vars.all()): | ||
| raise ValueError("some columns are constant or have no values") | ||
| # catch_warnings resets the registry for "once", while avoiding this | ||
| # warning would be annoying and slow, hence patching | ||
| orig_warn = warnings.warn | ||
| with patch("warnings.warn", new=nowarn): | ||
| means = np.nanmean(x, axis=0) | ||
| vars = np.nanvar(x, axis=0) | ||
| if self.normalize and not vars.all(): | ||
| raise ValueError("some columns are constant") | ||
| return EuclideanColumnsModel( | ||
| attributes, self.impute, self.normalize, means, vars) | ||
|
|
||
|
|
@@ -277,8 +295,12 @@ def fit_cols(self, attributes, x, n_vals): | |
| for normalization and imputation. | ||
| """ | ||
| self.check_no_discrete(n_vals) | ||
| medians = np.nanmedian(x, axis=0) | ||
| mads = np.nanmedian(np.abs(x - medians), axis=0) | ||
| if x.size == 0: | ||
| medians = np.zeros(len(x)) | ||
| mads = np.zeros(len(x)) | ||
| else: | ||
| medians = np.nanmedian(x, axis=0) | ||
| mads = np.nanmedian(np.abs(x - medians), axis=0) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The warning here was "mean of empty slice", which appeared because |
||
| if self.normalize and (np.isnan(mads).any() or not mads.all()): | ||
| raise ValueError( | ||
| "some columns have zero absolute distance from median, " | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
| from sklearn.metrics import silhouette_score, adjusted_mutual_info_score, silhouette_samples | ||
|
|
||
|
|
@@ -35,17 +37,22 @@ class ClusteringScore(Score): | |
|
|
||
| def from_predicted(self, results, score_function): | ||
| # Clustering scores from labels | ||
| if self.considers_actual: | ||
| return np.fromiter( | ||
| (score_function(results.actual.flatten(), predicted.flatten()) | ||
| for predicted in results.predicted), | ||
| dtype=np.float64, count=len(results.predicted)) | ||
| # Clustering scores from data only | ||
| else: | ||
| return np.fromiter( | ||
| (score_function(results.data.X, predicted.flatten()) | ||
| for predicted in results.predicted), | ||
| dtype=np.float64, count=len(results.predicted)) | ||
| # This warning filter can be removed in scikit 0.22 | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings( | ||
| "ignore", "The behavior of AMI will change in version 0\.22.*") | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I kept this |
||
| if self.considers_actual: | ||
| return np.fromiter( | ||
| (score_function(results.actual.flatten(), | ||
| predicted.flatten()) | ||
| for predicted in results.predicted), | ||
| dtype=np.float64, count=len(results.predicted)) | ||
| # Clustering scores from data only | ||
| else: | ||
| return np.fromiter( | ||
| (score_function(results.data.X, predicted.flatten()) | ||
| for predicted in results.predicted), | ||
| dtype=np.float64, count=len(results.predicted)) | ||
|
|
||
|
|
||
| class Silhouette(ClusteringScore): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x[~mask]unnecessarily copies the array, and I wouldn't like to do it column by column. This patching is a non-idiomatic substitute forcatch_warnings.