Skip to content

Commit a057121

Browse files
committed
Refactor function selection and test
1 parent bf73d48 commit a057121

File tree

2 files changed

+6
-11
lines changed

2 files changed

+6
-11
lines changed

lib/matplotlib/collections.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -545,14 +545,11 @@ def set_offsets(self, offsets):
545545
offsets = np.asanyarray(offsets)
546546
if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else.
547547
offsets = offsets[None, :]
548-
if isinstance(offsets, np.ma.MaskedArray):
549-
self._offsets = np.ma.column_stack(
550-
(np.asanyarray(self.convert_xunits(offsets[:, 0]), float),
551-
np.asanyarray(self.convert_yunits(offsets[:, 1]), float)))
552-
else:
553-
self._offsets = np.column_stack(
554-
(np.asanyarray(self.convert_xunits(offsets[:, 0]), float),
555-
np.asanyarray(self.convert_yunits(offsets[:, 1]), float)))
548+
cstack = (np.ma.column_stack if isinstance(offsets, np.ma.MaskedArray)
549+
else np.column_stack)
550+
self._offsets = cstack(
551+
(np.asanyarray(self.convert_xunits(offsets[:, 0]), float),
552+
np.asanyarray(self.convert_yunits(offsets[:, 1]), float)))
556553
self.stale = True
557554

558555
def get_offsets(self):

lib/matplotlib/tests/test_collections.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,15 +1158,14 @@ def test_masked_set_offsets(fig_ref, fig_test):
11581158

11591159
ax_test = fig_test.add_subplot()
11601160
scat = ax_test.scatter(x, y)
1161-
x += 1
11621161
scat.set_offsets(np.ma.column_stack([x, y]))
11631162
ax_test.set_xticks([])
11641163
ax_test.set_yticks([])
11651164
ax_test.set_xlim(0, 7)
11661165
ax_test.set_ylim(0, 6)
11671166

11681167
ax_ref = fig_ref.add_subplot()
1169-
ax_ref.scatter([2, 3, 6], [1, 2, 5])
1168+
ax_ref.scatter([1, 2, 5], [1, 2, 5])
11701169
ax_ref.set_xticks([])
11711170
ax_ref.set_yticks([])
11721171
ax_ref.set_xlim(0, 7)
@@ -1180,7 +1179,6 @@ def test_check_offsets_dtype():
11801179

11811180
fig, ax = plt.subplots()
11821181
scat = ax.scatter(x, y)
1183-
x += 1
11841182
masked_offsets = np.ma.column_stack([x, y])
11851183
scat.set_offsets(masked_offsets)
11861184
assert isinstance(scat.get_offsets(), type(masked_offsets))

0 commit comments

Comments
 (0)