
Commit 859af2c

[feat] Automatic audio-to-MIDI conversion
1. Support importing ordinary audio and converting it to MIDI automatically through a local model. 2. Support automatically downloading the model needed for transcription (no proxy required). 3. Show model download progress. 4. Show per-file transcription progress with an estimated time. 5. Only the CPU version of transcription is implemented so far.
1 parent 3ca34db commit 859af2c

File tree

10 files changed: +1613 -5 lines changed

__init__.py (3 additions & 0 deletions)
@@ -0,0 +1,3 @@
from .inference import PianoTranscription
from .config import sample_rate
from .utilities import load_audio
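Taken together, these exports suggest the following end-to-end call pattern. A minimal sketch, assuming a hypothetical package name and that load_audio follows the usual (audio, sample_rate) return convention (utilities.py is not part of this excerpt):

# Hypothetical usage; 'piano_transcriber' stands in for the real package name.
from piano_transcriber import PianoTranscription, sample_rate, load_audio

audio, _ = load_audio('recording.wav', sr=sample_rate, mono=True)    # assumed signature
transcriptor = PianoTranscription(device='cpu', gui_callback=print)  # CPU path per the commit notes
transcriptor.transcribe(audio, 'recording.mid')                      # writes the transcribed MIDI file
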
config.py (7 additions & 0 deletions)
@@ -0,0 +1,7 @@
sample_rate = 16000
classes_num = 88        # Number of notes on a piano
begin_note = 21         # MIDI note of A0, the lowest note on a piano
segment_seconds = 10.   # Training segment duration
hop_seconds = 1.
frames_per_second = 100
velocity_scale = 128
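These constants fix the segmentation used by inference.py below: a 10-second segment at 16 kHz is 160,000 samples, which matches the segment_samples=16000*10 default of PianoTranscription, and at 100 frames per second each segment yields 1,000 model output frames. A quick sketch of the arithmetic:

# Derived quantities implied by config.py (illustration only)
sample_rate = 16000
segment_seconds = 10.
frames_per_second = 100

segment_samples = int(sample_rate * segment_seconds)           # 160000 samples per segment
frames_per_segment = int(segment_seconds * frames_per_second)  # 1000 model frames per segment
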
inference.py (256 additions & 0 deletions)
@@ -0,0 +1,256 @@
import os
import numpy as np
import time
import librosa
from pathlib import Path
import urllib.request
import torch

from .utilities import (create_folder, get_filename, RegressionPostProcessor, write_events_to_midi)
from .models import Regress_onset_offset_frame_velocity_CRNN, Note_pedal
from .pytorch_utils import move_data_to_device, forward
from . import config


def download_with_progress(url, filename, progress_callback=None):
    """Download a file, reporting progress through an optional callback."""
    def hook(count, block_size, total_size):
        if total_size > 0 and progress_callback:
            downloaded = count * block_size
            percent = min(downloaded / total_size, 1.0)
            progress_callback(downloaded, total_size, percent)
    urllib.request.urlretrieve(url, filename, reporthook=hook)
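
# Example (hypothetical): print download progress to the console.
#
#   download_with_progress(
#       'https://example.com/model.pth', 'model.pth',
#       progress_callback=lambda done, total, pct: print(
#           f'{pct:.0%} ({done / 1e6:.1f}/{total / 1e6:.1f} MB)', end='\r'))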

class PianoTranscription(object):
    # Original initializer kept for reference; superseded by __init__ below.
    def __init__old(self, model_type='Note_pedal', checkpoint_path=None,
            segment_samples=16000*10, device=torch.device('cuda')):
        """Class for transcribing a piano solo recording.

        Args:
            model_type: str
            checkpoint_path: str
            segment_samples: int
            device: 'cuda' | 'cpu'
        """
        if not checkpoint_path:
            checkpoint_path = '{}/piano_transcription_inference_data/note_F1=0.9677_pedal_F1=0.9186.pth'.format(os.getcwd())
        print('Checkpoint path: {}'.format(checkpoint_path))

        # if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
        #     create_folder(os.path.dirname(checkpoint_path))
        #     print('Total size: ~165 MB')
        #     zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
        #     os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
        if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
            create_folder(os.path.dirname(checkpoint_path))
            print('Total size: ~165 MB')
            zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'

            try:
                print('Downloading model...')
                urllib.request.urlretrieve(zenodo_path, checkpoint_path)
                print('Download complete!')
            except Exception as e:
                print(f'Download failed: {e}')
                print(f'Please download it manually from: {zenodo_path}')
                print(f'and save it to: {checkpoint_path}')
        print('Using {} for inference.'.format(device))

        self.segment_samples = segment_samples
        self.frames_per_second = config.frames_per_second
        self.classes_num = config.classes_num
        self.onset_threshold = 0.3
        self.offset_threshold = 0.3
        self.frame_threshold = 0.1
        self.pedal_offset_threshold = 0.2

        # Build model
        Model = eval(model_type)   # resolve the model class by name, e.g. Note_pedal
        self.model = Model(frames_per_second=self.frames_per_second,
            classes_num=self.classes_num)

        # Load checkpoint weights
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        # Parallel
        if 'cuda' in str(device):
            self.model.to(device)
            print('GPU number: {}'.format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
        else:
            print('Using CPU.')

    def __init__(self, model_type='Note_pedal', checkpoint_path=None,
            segment_samples=16000*10, device=torch.device('cuda'), gui_callback=None):
        """Class for transcribing a piano solo recording.

        Args:
            model_type: str
            checkpoint_path: str
            segment_samples: int
            device: 'cuda' | 'cpu'
            gui_callback: callable or None, receives status strings for display
        """
        if not checkpoint_path:
            # checkpoint_path = os.path.join(os.getcwd(), 'piano_transcription_inference_data', 'note_F1=0.9677_pedal_F1=0.9186.pth')
            checkpoint_path = os.path.join(os.getcwd(), 'models', 'note_F13D0.9186.pth')

        # zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
        download_path = 'https://mirror-huggingface.nuist666.top/note_F13D0.9186.pth'

        # Download the checkpoint if it is missing or truncated (< ~160 MB)
        if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
            create_folder(os.path.dirname(checkpoint_path))
            if gui_callback:
                gui_callback("Downloading model...")
            try:
                # Forward download progress only when a GUI callback was supplied
                download_with_progress(
                    download_path, checkpoint_path,
                    progress_callback=(
                        (lambda d, t, p: gui_callback(
                            f"Downloading model: {p * 100:.1f}% ({d / 1e6:.1f}/{t / 1e6:.1f} MB)"))
                        if gui_callback else None))
                if gui_callback:
                    gui_callback("Model download complete!")
            except Exception as e:
                if gui_callback:
                    gui_callback(f"Download failed: {e}\nPlease download it manually to {checkpoint_path}")
                raise

        if gui_callback:
            gui_callback("Loading model...")

        print('Using {} for inference.'.format(device))

        self.segment_samples = segment_samples
        self.frames_per_second = config.frames_per_second
        self.classes_num = config.classes_num
        self.onset_threshold = 0.3
        self.offset_threshold = 0.3
        self.frame_threshold = 0.1
        self.pedal_offset_threshold = 0.2

        # Build model
        Model = eval(model_type)   # resolve the model class by name, e.g. Note_pedal
        self.model = Model(frames_per_second=self.frames_per_second,
            classes_num=self.classes_num)

        # Load checkpoint weights
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        # Parallel
        if 'cuda' in str(device):
            self.model.to(device)
            print('GPU number: {}'.format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
        else:
            print('Using CPU.')

    def transcribe(self, audio, midi_path, gui_callback=None):
        """Transcribe an audio recording.

        Args:
            audio: (audio_samples,)
            midi_path: str, path to write out the transcribed MIDI.

        Returns:
            transcribed_dict: dict, {'output_dict': ..., 'est_note_events': ...,
                'est_pedal_events': ...}
        """
        audio = audio[None, :]   # (1, audio_samples)

        # Pad audio so its length divides evenly into segment_samples
        audio_len = audio.shape[1]
        pad_len = int(np.ceil(audio_len / self.segment_samples)) \
            * self.segment_samples - audio_len

        audio = np.concatenate((audio, np.zeros((1, pad_len))), axis=1)

        # Enframe to half-overlapping segments
        segments = self.enframe(audio, self.segment_samples)
        """(N, segment_samples)"""

        # Forward; progress and time estimates are reported through gui_callback
        # output_dict = forward(self.model, segments, batch_size=1)
        output_dict = forward(self.model, segments, batch_size=1, progress_callback=gui_callback)
        """{'reg_onset_output': (N, segment_frames, classes_num), ...}"""

        # Deframe to original length
        for key in output_dict.keys():
            output_dict[key] = self.deframe(output_dict[key])[0 : audio_len]
        """output_dict: {
            'reg_onset_output': (N, segment_frames, classes_num),
            'reg_offset_output': (N, segment_frames, classes_num),
            'frame_output': (N, segment_frames, classes_num),
            'velocity_output': (N, segment_frames, classes_num)}"""

        # Post processor
        post_processor = RegressionPostProcessor(self.frames_per_second,
            classes_num=self.classes_num, onset_threshold=self.onset_threshold,
            offset_threshold=self.offset_threshold,
            frame_threshold=self.frame_threshold,
            pedal_offset_threshold=self.pedal_offset_threshold)

        # Post process output_dict to MIDI events
        (est_note_events, est_pedal_events) = \
            post_processor.output_dict_to_midi_events(output_dict)

        # Write MIDI events to file
        if midi_path:
            write_events_to_midi(start_time=0, note_events=est_note_events,
                pedal_events=est_pedal_events, midi_path=midi_path)
            print('Write out to {}'.format(midi_path))

        transcribed_dict = {
            'output_dict': output_dict,
            'est_note_events': est_note_events,
            'est_pedal_events': est_pedal_events}

        return transcribed_dict

    def enframe(self, x, segment_samples):
        """Enframe a long sequence into short, half-overlapping segments.

        Args:
            x: (1, audio_samples)
            segment_samples: int

        Returns:
            batch: (N, segment_samples)
        """
        assert x.shape[1] % segment_samples == 0
        batch = []

        pointer = 0
        while pointer + segment_samples <= x.shape[1]:
            batch.append(x[:, pointer : pointer + segment_samples])
            pointer += segment_samples // 2   # 50% hop: adjacent segments overlap by half

        batch = np.concatenate(batch, axis=0)
        return batch

    def deframe(self, x):
        """Deframe predicted segments back into the original sequence.

        Args:
            x: (N, segment_frames, classes_num)

        Returns:
            y: (audio_frames, classes_num)
        """
        if x.shape[0] == 1:
            return x[0]

        else:
            x = x[:, 0 : -1, :]
            """Remove the extra frame at the end of each segment caused by the
            'center=True' argument when calculating the spectrogram."""
            (N, segment_samples, classes_num) = x.shape
            assert segment_samples % 4 == 0

            # Stitch half-overlapping segments: keep the first 75% of the first
            # segment, the middle 50% of each interior segment, and the last
            # 75% of the final segment.
            y = []
            y.append(x[0, 0 : int(segment_samples * 0.75)])
            for i in range(1, N - 1):
                y.append(x[i, int(segment_samples * 0.25) : int(segment_samples * 0.75)])
            y.append(x[-1, int(segment_samples * 0.25) :])
            y = np.concatenate(y, axis=0)
            return y
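
The enframe/deframe pair implements overlap-discard stitching: segments are cut with a 50% hop, and on reassembly each segment contributes only frames away from its edges, where the model has the most context. A self-contained sketch of that bookkeeping on toy data (it mirrors the two methods above in 2-D, minus the per-segment spectrogram-frame trim; not part of the commit):

import numpy as np

def enframe(x, segment_samples):
    # Same logic as PianoTranscription.enframe: slide a window with a 50% hop.
    assert x.shape[1] % segment_samples == 0
    batch, pointer = [], 0
    while pointer + segment_samples <= x.shape[1]:
        batch.append(x[:, pointer : pointer + segment_samples])
        pointer += segment_samples // 2
    return np.concatenate(batch, axis=0)

def deframe(x):
    # First 75% of the first segment, middle 50% of interior segments,
    # last 75% of the final segment.
    if x.shape[0] == 1:
        return x[0]
    (N, S) = x.shape
    assert S % 4 == 0
    y = [x[0, 0 : int(S * 0.75)]]
    for i in range(1, N - 1):
        y.append(x[i, int(S * 0.25) : int(S * 0.75)])
    y.append(x[-1, int(S * 0.25) :])
    return np.concatenate(y, axis=0)

x = np.arange(32, dtype=float)[None, :]   # a 32-sample "recording"
segments = enframe(x, segment_samples=8)  # (7, 8): hop of 4 samples
assert np.array_equal(deframe(segments), x[0])  # round-trips exactly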
