55import os
66import sys
77from copy import deepcopy
8- import math
98
109import codecs
1110import numpy as np
1716logging .basicConfig (stream = sys .stdout , level = logging .DEBUG )
1817logger = logging
1918
19+ logging .getLogger ("numba" ).setLevel (logging .WARNING )
20+ logging .getLogger ("markdown_it" ).setLevel (logging .WARNING )
21+ logging .getLogger ("urllib3" ).setLevel (logging .WARNING )
22+ logging .getLogger ("matplotlib" ).setLevel (logging .WARNING )
23+
24+ import matplotlib .pylab as plt
25+
2026
2127def load_checkpoint (checkpoint_path , model , optimizer = None , load_opt = 1 ):
2228 assert os .path .isfile (checkpoint_path )
@@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram):
125131 MATPLOTLIB_FLAG = True
126132 mpl_logger = logging .getLogger ("matplotlib" )
127133 mpl_logger .setLevel (logging .WARNING )
128- import matplotlib .pylab as plt
129- import numpy as np
130134
131135 fig , ax = plt .subplots (figsize = (10 , 2 ))
132136 im = ax .imshow (spectrogram , aspect = "auto" , origin = "lower" , interpolation = "none" )
@@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram):
136140 plt .tight_layout ()
137141
138142 fig .canvas .draw ()
139- data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
140- data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
143+ try :
144+ data = np .array (fig .canvas .renderer .buffer_rgba (), dtype = np .uint8 )
145+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (4 ,))[:, :, :3 ] # 只取前三个通道(RGB)
146+ except :
147+ data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
148+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
141149 plt .close ()
142150 return data
143151
@@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None):
151159 MATPLOTLIB_FLAG = True
152160 mpl_logger = logging .getLogger ("matplotlib" )
153161 mpl_logger .setLevel (logging .WARNING )
154- import matplotlib .pylab as plt
155- import numpy as np
156162
157163 fig , ax = plt .subplots (figsize = (6 , 4 ))
158164 im = ax .imshow (
@@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None):
167173 plt .tight_layout ()
168174
169175 fig .canvas .draw ()
170- data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
171- data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
176+ try :
177+ data = np .array (fig .canvas .renderer .buffer_rgba (), dtype = np .uint8 )
178+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (4 ,))[:, :, :3 ] # 只取前三个通道(RGB)
179+ except :
180+ data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
181+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
172182 plt .close ()
173183 return data
174184
0 commit comments