Skip to content

Commit eb0b5cc

Browse files
authored
Support array-like mask in heatmaps (#3803)
* Support array-like mask in heatmaps * Nit
1 parent e0c2431 commit eb0b5cc

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

seaborn/matrix.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,14 @@ def _matrix_mask(data, mask):
6969
if mask is None:
7070
mask = np.zeros(data.shape, bool)
7171

72-
if isinstance(mask, np.ndarray):
72+
if isinstance(mask, pd.DataFrame):
73+
# For DataFrame masks, ensure that semantic labels match data
74+
if not mask.index.equals(data.index) \
75+
and mask.columns.equals(data.columns):
76+
err = "Mask must have the same index and columns as data."
77+
raise ValueError(err)
78+
elif hasattr(mask, "__array__"):
79+
mask = np.asarray(mask)
7380
# For array masks, ensure that shape matches data then convert
7481
if mask.shape != data.shape:
7582
raise ValueError("Mask must have the same shape as data.")
@@ -79,13 +86,6 @@ def _matrix_mask(data, mask):
7986
columns=data.columns,
8087
dtype=bool)
8188

82-
elif isinstance(mask, pd.DataFrame):
83-
# For DataFrame masks, ensure that semantic labels match data
84-
if not mask.index.equals(data.index) \
85-
and mask.columns.equals(data.columns):
86-
err = "Mask must have the same index and columns as data."
87-
raise ValueError(err)
88-
8989
# Add any cells with missing data to the mask
9090
# This works around an issue where `plt.pcolormesh` doesn't represent
9191
# missing data properly

tests/test_matrix.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,24 @@ def test_ndarray_input(self):
5656
assert p.xlabel == ""
5757
assert p.ylabel == ""
5858

59+
def test_array_like_input(self):
60+
class ArrayLike:
61+
def __init__(self, data):
62+
self.data = data
63+
64+
def __array__(self, dtype=None, copy=None):
65+
return np.asarray(self.data, dtype=dtype, copy=copy)
66+
67+
p = mat._HeatMapper(ArrayLike(self.x_norm), **self.default_kws)
68+
npt.assert_array_equal(p.plot_data, self.x_norm)
69+
pdt.assert_frame_equal(p.data, pd.DataFrame(self.x_norm))
70+
71+
npt.assert_array_equal(p.xticklabels, np.arange(8))
72+
npt.assert_array_equal(p.yticklabels, np.arange(4))
73+
74+
assert p.xlabel == ""
75+
assert p.ylabel == ""
76+
5977
def test_df_input(self):
6078

6179
p = mat._HeatMapper(self.df_norm, **self.default_kws)

0 commit comments

Comments
 (0)