Skip to content

Commit fe11be3

Browse files
committed
fix(train): matplotlib deprecation (#103)
1 parent 89f7fa2 commit fe11be3

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

infer/lib/train/utils.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import sys
77
from copy import deepcopy
8-
import math
98

109
import codecs
1110
import numpy as np
@@ -17,6 +16,13 @@
1716
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
1817
logger = logging
1918

19+
logging.getLogger("numba").setLevel(logging.WARNING)
20+
logging.getLogger("markdown_it").setLevel(logging.WARNING)
21+
logging.getLogger("urllib3").setLevel(logging.WARNING)
22+
logging.getLogger("matplotlib").setLevel(logging.WARNING)
23+
24+
import matplotlib.pylab as plt
25+
2026

2127
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
2228
assert os.path.isfile(checkpoint_path)
@@ -125,8 +131,6 @@ def plot_spectrogram_to_numpy(spectrogram):
125131
MATPLOTLIB_FLAG = True
126132
mpl_logger = logging.getLogger("matplotlib")
127133
mpl_logger.setLevel(logging.WARNING)
128-
import matplotlib.pylab as plt
129-
import numpy as np
130134

131135
fig, ax = plt.subplots(figsize=(10, 2))
132136
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
@@ -136,8 +140,12 @@ def plot_spectrogram_to_numpy(spectrogram):
136140
plt.tight_layout()
137141

138142
fig.canvas.draw()
139-
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
140-
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
143+
try:
144+
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
145+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB)
146+
except:
147+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
148+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
141149
plt.close()
142150
return data
143151

@@ -151,8 +159,6 @@ def plot_alignment_to_numpy(alignment, info=None):
151159
MATPLOTLIB_FLAG = True
152160
mpl_logger = logging.getLogger("matplotlib")
153161
mpl_logger.setLevel(logging.WARNING)
154-
import matplotlib.pylab as plt
155-
import numpy as np
156162

157163
fig, ax = plt.subplots(figsize=(6, 4))
158164
im = ax.imshow(
@@ -167,8 +173,12 @@ def plot_alignment_to_numpy(alignment, info=None):
167173
plt.tight_layout()
168174

169175
fig.canvas.draw()
170-
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
171-
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
176+
try:
177+
data = np.array(fig.canvas.renderer.buffer_rgba(), dtype=np.uint8)
178+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3] # 只取前三个通道(RGB)
179+
except:
180+
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
181+
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
172182
plt.close()
173183
return data
174184

0 commit comments

Comments
 (0)