@@ -193,24 +193,23 @@ def stats(X, weights=None, compute_variance=False):
193193 is_sparse = sp .issparse (X )
194194 weighted = weights is not None and X .dtype != object
195195
196- if weighted :
197- weights = np .c_ [weights ] / sum (weights )
196+ def weighted_mean ():
198197 if is_sparse :
199- w_X = X .multiply (sp .csr_matrix (weights ))
200- weighted_mean = np .asarray (w_X .sum (axis = 0 )).ravel ()
198+ w_X = X .multiply (sp .csr_matrix (np . c_ [ weights ] / sum ( weights ) ))
199+ return np .asarray (w_X .sum (axis = 0 )).ravel ()
201200 else :
202- weighted_mean = np .nansum (X * weights , axis = 0 )
201+ return np .nansum (X * np . c_ [ weights ] / sum ( weights ) , axis = 0 )
203202
204203 if X .size and is_numeric and not is_sparse :
205204 nans = np .isnan (X ).sum (axis = 0 )
206205 return np .column_stack ((
207206 np .nanmin (X , axis = 0 ),
208207 np .nanmax (X , axis = 0 ),
209- np .nanmean (X , axis = 0 ) if not weighted else weighted_mean ,
208+ np .nanmean (X , axis = 0 ) if not weighted else weighted_mean () ,
210209 np .nanvar (X , axis = 0 ) if compute_variance else np .zeros (X .shape [1 ]),
211210 nans ,
212211 X .shape [0 ] - nans ))
213- elif is_sparse :
212+ elif is_sparse and X . size :
214213 if compute_variance :
215214 raise NotImplementedError
216215
@@ -219,7 +218,7 @@ def stats(X, weights=None, compute_variance=False):
219218 return np .column_stack ((
220219 X .min (axis = 0 ).toarray ().ravel (),
221220 X .max (axis = 0 ).toarray ().ravel (),
222- np .asarray (X .mean (axis = 0 )).ravel () if not weighted else weighted_mean ,
221+ np .asarray (X .mean (axis = 0 )).ravel () if not weighted else weighted_mean () ,
223222 np .zeros (X .shape [1 ]), # variance not supported
224223 X .shape [0 ] - non_zero ,
225224 non_zero ))
0 commit comments