diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index 0f270c6f6e546..c689b62432847 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -4832,6 +4832,64 @@ def combineMult(self, other):
         """
         return self.mul(other, fill_value=1.)
 
+    def where(self, cond, other):
+        """
+        Return a DataFrame with the same shape as self and whose corresponding
+        entries are from self where cond is True and otherwise are from other.
+
+        Parameters
+        ----------
+        cond : boolean DataFrame or array
+            Array input must have the same shape as self. DataFrame input is
+            aligned to self by label; labels of self that are missing from
+            cond are treated as False.
+        other : scalar, array or DataFrame
+            DataFrame input is left-aligned to self; unmatched entries
+            become NaN.
+
+        Returns
+        -------
+        wh : DataFrame
+
+        Raises
+        ------
+        ValueError
+            If cond is an array whose shape differs from that of self
+        """
+        if isinstance(cond, np.ndarray):
+            if cond.shape != self.shape:
+                raise ValueError('Array conditional must be same shape as '
+                                 'self')
+            cond = self._constructor(cond, index=self.index,
+                                     columns=self.columns)
+        else:
+            # Equal shapes do not imply equal labels, so always align by
+            # label rather than only when the shapes differ.
+            cond = cond.reindex(self.index, columns=self.columns)
+            cond = cond.fillna(False)
+
+        if isinstance(other, DataFrame):
+            _, other = self.align(other, join='left', fill_value=np.nan)
+
+        rs = np.where(cond, self, other)
+        return self._constructor(rs, self.index, self.columns)
+
+    def mask(self, cond):
+        """
+        Return a copy of self in which every entry whose corresponding entry
+        in cond is False is replaced with NaN.
+
+        Equivalent to ``self.where(cond, np.nan)``.
+
+        Parameters
+        ----------
+        cond : boolean DataFrame or array
+
+        Returns
+        -------
+        wh : DataFrame
+        """
+        return self.where(cond, np.nan)
 
 _EMPTY_SERIES = Series([])
 
diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py
index c989e8c981231..3cea6c50f40f3 100644
--- a/pandas/tests/test_frame.py
+++ b/pandas/tests/test_frame.py
@@ -5063,6 +5063,38 @@ def test_align_int_fill_bug(self):
         expected = df2 - df2.mean()
         assert_frame_equal(result, expected)
 
+    def test_where(self):
+        df = DataFrame(np.random.randn(5, 3))
+        cond = df > 0
+
+        other1 = df + 1
+        rs = df.where(cond, other1)
+        for k, v in rs.iteritems():
+            assert_series_equal(v, np.where(cond[k], df[k], other1[k]))
+
+        other2 = (df + 1).values
+        rs = df.where(cond, other2)
+        for k, v in rs.iteritems():
+            assert_series_equal(v, np.where(cond[k], df[k], other2[:, k]))
+
+        other5 = np.nan
+        rs = df.where(cond, other5)
+        for k, v in rs.iteritems():
+            assert_series_equal(v, np.where(cond[k], df[k], other5))
+
+        assert_frame_equal(rs, df.mask(cond))
+
+        # same shape but different row order: cond is aligned by label
+        cond2 = cond.reindex(df.index[::-1])
+        assert_frame_equal(df.where(cond2, other1), df.where(cond, other1))
+
+        err1 = (df + 1).values[0:2, :]
+        self.assertRaises(ValueError, df.where, cond, err1)
+
+        err2 = cond.ix[:2, :].values
+        self.assertRaises(ValueError, df.where, err2, other1)
+
+
 #----------------------------------------------------------------------
 # Transposing
 