@@ -458,15 +458,35 @@ def test_imshow():
458458 ax .set_ylim (0 , 3 )
459459
460460
461- @image_comparison ([ 'imshow' ], remove_text = True , style = 'mpl20' )
462- def test_imshow_10_10_1 ():
463- fig , ax = plt . subplots ()
461+ @check_figures_equal ( extensions = [ 'png' ] )
462+ def test_imshow_10_10_1 (fig_test , fig_ref ):
463+ # 10x10x1 should be the same as 10x10
464464 arr = np .arange (100 ).reshape ((10 , 10 , 1 ))
465+ ax = fig_ref .subplots ()
466+ ax .imshow (arr [:, :, 0 ], interpolation = "bilinear" , extent = (1 , 2 , 1 , 2 ))
467+ ax .set_xlim (0 , 3 )
468+ ax .set_ylim (0 , 3 )
469+
470+ ax = fig_test .subplots ()
465471 ax .imshow (arr , interpolation = "bilinear" , extent = (1 , 2 , 1 , 2 ))
466472 ax .set_xlim (0 , 3 )
467473 ax .set_ylim (0 , 3 )
468474
469475
476+ def test_imshow_10_10_2 ():
477+ fig , ax = plt .subplots ()
478+ arr = np .arange (200 ).reshape ((10 , 10 , 2 ))
479+ with pytest .raises (TypeError ):
480+ ax .imshow (arr )
481+
482+
483+ def test_imshow_10_10_5 ():
484+ fig , ax = plt .subplots ()
485+ arr = np .arange (500 ).reshape ((10 , 10 , 5 ))
486+ with pytest .raises (TypeError ):
487+ ax .imshow (arr )
488+
489+
470490@image_comparison (['no_interpolation_origin' ], remove_text = True )
471491def test_no_interpolation_origin ():
472492 fig , axs = plt .subplots (2 )
0 commit comments