@@ -270,10 +270,18 @@ def plt_to_np(plt, close=True, axis=False):
270270 plt .axis ('off' )
271271 fig = plt .gcf ()
272272 fig .canvas .draw ()
273- data = np .frombuffer (fig .canvas .tostring_rgb (), dtype = np .uint8 )
274- w , h = fig .canvas .get_width_height ()
275273 try :
276- np_img = data .reshape ((int (h ), int (w ), - 1 ))
274+ w , h = fig .canvas .get_width_height ()
275+ if hasattr (fig .canvas , 'tostring_rgb' ):
276+ data = np .frombuffer (fig .canvas .tostring_rgb (), dtype = np .uint8 )
277+ np_img = data .reshape ((int (h ), int (w ), 3 ))
278+ elif hasattr (fig .canvas , 'tostring_argb' ):
279+ data = np .frombuffer (fig .canvas .tostring_argb (), dtype = np .uint8 )
280+ argb = data .reshape ((int (h ), int (w ), 4 ))
281+ np_img = argb [:, :, 1 :4 ]
282+ else :
283+ rgba = np .asarray (fig .canvas .buffer_rgba ())
284+ np_img = rgba [:, :, :3 ]
277285 except Exception as e :
278286 print (e )
279287 np_img = None
@@ -440,4 +448,4 @@ def image_resize(image, width=None, height=None, inter=None):
440448 resized = cv2 .resize (image , dim , interpolation = inter )
441449
442450 # return the resized image
443- return resized
451+ return resized
0 commit comments