Skip to content

Commit 02c8199

Browse files
authored
Merge pull request matplotlib#15090 from kmader/patch-1
ENH: Coerce MxNx1 images into MxN images for imshow
2 parents 126de7f + 5d27567 commit 02c8199

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Imshow now coerces 3D arrays with depth 1 to 2D
2+
------------------------------------------------
3+
Starting from this version arrays of size MxNx1 will be coerced into MxN
4+
for displaying. This means commands like ``plt.imshow(np.random.rand(3, 3, 1))``
5+
will no longer return an error message that the image shape is invalid.

lib/matplotlib/image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,10 @@ def set_data(self, A):
684684
raise TypeError("Image data of dtype {} cannot be converted to "
685685
"float".format(self._A.dtype))
686686

687+
if (self._A.ndim == 3 and self._A.shape[-1] == 1):
688+
# If just one dimension assume scalar and apply colormap
689+
self._A = self._A[:, :, 0]
690+
687691
if not (self._A.ndim == 2
688692
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
689693
raise TypeError("Invalid shape {} for image data"

lib/matplotlib/tests/test_image.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,35 @@ def test_imshow():
467467
ax.set_ylim(0, 3)
468468

469469

470+
@check_figures_equal(extensions=['png'])
471+
def test_imshow_10_10_1(fig_test, fig_ref):
472+
# 10x10x1 should be the same as 10x10
473+
arr = np.arange(100).reshape((10, 10, 1))
474+
ax = fig_ref.subplots()
475+
ax.imshow(arr[:, :, 0], interpolation="bilinear", extent=(1, 2, 1, 2))
476+
ax.set_xlim(0, 3)
477+
ax.set_ylim(0, 3)
478+
479+
ax = fig_test.subplots()
480+
ax.imshow(arr, interpolation="bilinear", extent=(1, 2, 1, 2))
481+
ax.set_xlim(0, 3)
482+
ax.set_ylim(0, 3)
483+
484+
485+
def test_imshow_10_10_2():
486+
fig, ax = plt.subplots()
487+
arr = np.arange(200).reshape((10, 10, 2))
488+
with pytest.raises(TypeError):
489+
ax.imshow(arr)
490+
491+
492+
def test_imshow_10_10_5():
493+
fig, ax = plt.subplots()
494+
arr = np.arange(500).reshape((10, 10, 5))
495+
with pytest.raises(TypeError):
496+
ax.imshow(arr)
497+
498+
470499
@image_comparison(['no_interpolation_origin'], remove_text=True)
471500
def test_no_interpolation_origin():
472501
fig, axs = plt.subplots(2)

0 commit comments

Comments
 (0)