5
5
import os
6
6
import sys
7
7
from copy import deepcopy
8
- import math
9
8
10
9
import codecs
11
10
import numpy as np
17
16
logging .basicConfig (stream = sys .stdout , level = logging .DEBUG )
18
17
logger = logging
19
18
19
+ logging .getLogger ("numba" ).setLevel (logging .WARNING )
20
+ logging .getLogger ("markdown_it" ).setLevel (logging .WARNING )
21
+ logging .getLogger ("urllib3" ).setLevel (logging .WARNING )
22
+ logging .getLogger ("matplotlib" ).setLevel (logging .WARNING )
23
+
24
+ import matplotlib .pylab as plt
25
+
20
26
21
27
def load_checkpoint (checkpoint_path , model , optimizer = None , load_opt = 1 ):
22
28
assert os .path .isfile (checkpoint_path )
@@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram):
125
131
MATPLOTLIB_FLAG = True
126
132
mpl_logger = logging .getLogger ("matplotlib" )
127
133
mpl_logger .setLevel (logging .WARNING )
128
- import matplotlib .pylab as plt
129
- import numpy as np
130
134
131
135
fig , ax = plt .subplots (figsize = (10 , 2 ))
132
136
im = ax .imshow (spectrogram , aspect = "auto" , origin = "lower" , interpolation = "none" )
@@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram):
136
140
plt .tight_layout ()
137
141
138
142
fig .canvas .draw ()
139
- data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
140
- data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
143
+ try :
144
+ data = np .array (fig .canvas .renderer .buffer_rgba (), dtype = np .uint8 )
145
+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (4 ,))[:, :, :3 ] # 只取前三个通道(RGB)
146
+ except :
147
+ data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
148
+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
141
149
plt .close ()
142
150
return data
143
151
@@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None):
151
159
MATPLOTLIB_FLAG = True
152
160
mpl_logger = logging .getLogger ("matplotlib" )
153
161
mpl_logger .setLevel (logging .WARNING )
154
- import matplotlib .pylab as plt
155
- import numpy as np
156
162
157
163
fig , ax = plt .subplots (figsize = (6 , 4 ))
158
164
im = ax .imshow (
@@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None):
167
173
plt .tight_layout ()
168
174
169
175
fig .canvas .draw ()
170
- data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
171
- data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
176
+ try :
177
+ data = np .array (fig .canvas .renderer .buffer_rgba (), dtype = np .uint8 )
178
+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (4 ,))[:, :, :3 ] # 只取前三个通道(RGB)
179
+ except :
180
+ data = np .fromstring (fig .canvas .tostring_rgb (), dtype = np .uint8 , sep = "" )
181
+ data = data .reshape (fig .canvas .get_width_height ()[::- 1 ] + (3 ,))
172
182
plt .close ()
173
183
return data
174
184
0 commit comments