import os
import numpy as np
import time
import librosa
from pathlib import Path
import urllib.request
import torch

from .utilities import (create_folder, get_filename, RegressionPostProcessor,
    write_events_to_midi)
from .models import Regress_onset_offset_frame_velocity_CRNN, Note_pedal
from .pytorch_utils import move_data_to_device, forward
from . import config


def download_with_progress(url, filename, progress_callback=None):
    """Download url to filename, reporting progress through progress_callback,
    which receives (downloaded_bytes, total_bytes, fraction_complete)."""
    def hook(count, block_size, total_size):
        if total_size > 0 and progress_callback:
            downloaded = count * block_size
            percent = min(downloaded / total_size, 1.0)
            progress_callback(downloaded, total_size, percent)
    urllib.request.urlretrieve(url, filename, reporthook=hook)
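# Usage sketch (hypothetical URL and file name; any callable matching the
# (downloaded, total, fraction) signature works as the callback):
#
#     download_with_progress(
#         'https://example.com/model.pth', 'model.pth',
#         progress_callback=lambda d, t, p: print(f'{p * 100:.1f}%'))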


class PianoTranscription(object):
    # Superseded constructor, kept for reference; Python only calls the
    # __init__ defined below.
    def __init__old(self, model_type='Note_pedal', checkpoint_path=None,
        segment_samples=16000 * 10, device=torch.device('cuda')):
        """Class for transcribing piano solo recording.

        Args:
          model_type: str
          checkpoint_path: str
          segment_samples: int
          device: 'cuda' | 'cpu'
        """
        if not checkpoint_path:
            checkpoint_path = '{}/piano_transcription_inference_data/note_F1=0.9677_pedal_F1=0.9186.pth'.format(os.getcwd())
        print('Checkpoint path: {}'.format(checkpoint_path))

        # Previous approach, downloading via wget:
        # if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
        #     create_folder(os.path.dirname(checkpoint_path))
        #     print('Total size: ~165 MB')
        #     zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
        #     os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
        if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
            create_folder(os.path.dirname(checkpoint_path))
            print('Total size: ~165 MB')
            zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'

            try:
                print('Downloading model...')
                urllib.request.urlretrieve(zenodo_path, checkpoint_path)
                print('Download complete!')
            except Exception as e:
                print(f'Download failed: {e}')
                print(f'Please download it manually from: {zenodo_path}')
                print(f'and save it to: {checkpoint_path}')
        print('Using {} for inference.'.format(device))

        self.segment_samples = segment_samples
        self.frames_per_second = config.frames_per_second
        self.classes_num = config.classes_num
        self.onset_threshold = 0.3
        self.offset_threshold = 0.3
        self.frame_threshold = 0.1
        self.pedal_offset_threshold = 0.2

        # Build model: resolve the model class (imported above) from its name
        Model = eval(model_type)
        self.model = Model(frames_per_second=self.frames_per_second,
            classes_num=self.classes_num)

        # Load checkpoint weights
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        # Parallel
        if 'cuda' in str(device):
            self.model.to(device)
            print('GPU number: {}'.format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
        else:
            print('Using CPU.')
    def __init__(self, model_type='Note_pedal', checkpoint_path=None,
        segment_samples=16000 * 10, device=torch.device('cuda'),
        gui_callback=None):
        """Class for transcribing piano solo recording.

        Args:
          model_type: str
          checkpoint_path: str
          segment_samples: int
          device: 'cuda' | 'cpu'
          gui_callback: callable or None, receives status strings for a GUI
        """
        if not checkpoint_path:
            # checkpoint_path = os.path.join(os.getcwd(), 'piano_transcription_inference_data', 'note_F1=0.9677_pedal_F1=0.9186.pth')
            checkpoint_path = os.path.join(os.getcwd(), 'models', 'note_F13D0.9186.pth')

        # zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
        download_path = 'https://mirror-huggingface.nuist666.top/note_F13D0.9186.pth'

        if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
            create_folder(os.path.dirname(checkpoint_path))
            if gui_callback:
                gui_callback('Downloading model...')
            try:
                # Only report progress when a GUI callback was provided;
                # passing the lambda unconditionally would crash on None.
                download_with_progress(
                    download_path, checkpoint_path,
                    progress_callback=(lambda d, t, p: gui_callback(
                        f'Downloading model: {p * 100:.1f}% '
                        f'({d / 1e6:.1f}/{t / 1e6:.1f} MB)'))
                        if gui_callback else None)
                if gui_callback:
                    gui_callback('Model download complete!')
            except Exception as e:
                if gui_callback:
                    gui_callback(f'Download failed: {e}\nPlease download it manually to {checkpoint_path}')
                raise

        if gui_callback:
            gui_callback('Loading model...')

        print('Using {} for inference.'.format(device))

        self.segment_samples = segment_samples
        self.frames_per_second = config.frames_per_second
        self.classes_num = config.classes_num
        self.onset_threshold = 0.3
        self.offset_threshold = 0.3
        self.frame_threshold = 0.1
        self.pedal_offset_threshold = 0.2

        # Build model: resolve the model class (imported above) from its name
        Model = eval(model_type)
        self.model = Model(frames_per_second=self.frames_per_second,
            classes_num=self.classes_num)

        # Load checkpoint weights
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'], strict=False)

        # Parallel
        if 'cuda' in str(device):
            self.model.to(device)
            print('GPU number: {}'.format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)
        else:
            print('Using CPU.')
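    # Usage sketch (hypothetical `status_label`): any callable taking a single
    # status string works, e.g. routing progress text to a Tkinter label:
    #
    #     transcriptor = PianoTranscription(
    #         device=torch.device('cpu'),
    #         gui_callback=lambda msg: status_label.config(text=msg))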

    def transcribe(self, audio, midi_path, gui_callback=None):
        """Transcribe an audio recording.

        Args:
          audio: (audio_samples,)
          midi_path: str, path to write out the transcribed MIDI.

        Returns:
          transcribed_dict, dict: {'output_dict': ...,
            'est_note_events': ..., 'est_pedal_events': ...}
        """
        audio = audio[None, :]  # (1, audio_samples)

        # Pad audio to be evenly divided by segment_samples
        audio_len = audio.shape[1]
        pad_len = int(np.ceil(audio_len / self.segment_samples)) \
            * self.segment_samples - audio_len

        audio = np.concatenate((audio, np.zeros((1, pad_len))), axis=1)

        # Enframe to segments
        segments = self.enframe(audio, self.segment_samples)
        """(N, segment_samples)"""

        # Forward
        # output_dict = forward(self.model, segments, batch_size=1)
        output_dict = forward(self.model, segments, batch_size=1,
            progress_callback=gui_callback)
        """{'reg_onset_output': (N, segment_frames, classes_num), ...}"""

        # Deframe to original length
        for key in output_dict.keys():
            output_dict[key] = self.deframe(output_dict[key])[0 : audio_len]
        """output_dict: {
          'reg_onset_output': (audio_frames, classes_num),
          'reg_offset_output': (audio_frames, classes_num),
          'frame_output': (audio_frames, classes_num),
          'velocity_output': (audio_frames, classes_num)}"""

        # Post processor
        post_processor = RegressionPostProcessor(self.frames_per_second,
            classes_num=self.classes_num, onset_threshold=self.onset_threshold,
            offset_threshold=self.offset_threshold,
            frame_threshold=self.frame_threshold,
            pedal_offset_threshold=self.pedal_offset_threshold)

        # Post process output_dict to MIDI events
        (est_note_events, est_pedal_events) = \
            post_processor.output_dict_to_midi_events(output_dict)

        # Write MIDI events to file
        if midi_path:
            write_events_to_midi(start_time=0, note_events=est_note_events,
                pedal_events=est_pedal_events, midi_path=midi_path)
            print('Write out to {}'.format(midi_path))

        transcribed_dict = {
            'output_dict': output_dict,
            'est_note_events': est_note_events,
            'est_pedal_events': est_pedal_events}

        return transcribed_dict

    def enframe(self, x, segment_samples):
        """Enframe long sequence into short segments with 50% overlap.

        Args:
          x: (1, audio_samples)
          segment_samples: int

        Returns:
          batch: (N, segment_samples)
        """
        assert x.shape[1] % segment_samples == 0
        batch = []

        pointer = 0
        while pointer + segment_samples <= x.shape[1]:
            batch.append(x[:, pointer : pointer + segment_samples])
            pointer += segment_samples // 2  # hop by half a segment

        batch = np.concatenate(batch, axis=0)
        return batch
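    # Example: with segment_samples=4 and x of shape (1, 8), enframe returns
    # segments starting at samples 0, 2 and 4, i.e. a (3, 4) batch; the
    # half-segment hop gives deframe overlapping regions to stitch together.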

    def deframe(self, x):
        """Deframe predicted segments to original sequence.

        Args:
          x: (N, segment_frames, classes_num)

        Returns:
          y: (audio_frames, classes_num)
        """
        if x.shape[0] == 1:
            return x[0]

        else:
            x = x[:, 0 : -1, :]
            """Remove an extra frame at the end of each segment caused by the
            'center=True' argument when calculating the spectrogram."""
            (N, segment_frames, classes_num) = x.shape
            assert segment_frames % 4 == 0

            # Stitch overlapping segments: keep the first 75% of the first
            # segment, the middle 50% of interior segments, and the last 75%
            # of the final segment, so overlapping halves are not duplicated.
            y = []
            y.append(x[0, 0 : int(segment_frames * 0.75)])
            for i in range(1, N - 1):
                y.append(x[i, int(segment_frames * 0.25) : int(segment_frames * 0.75)])
            y.append(x[-1, int(segment_frames * 0.25) :])
            y = np.concatenate(y, axis=0)
            return y
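

# Minimal end-to-end sketch. Assumptions: an 'example.wav' file exists,
# config defines sample_rate (16 kHz for this checkpoint), and the module is
# run with its package on the path (e.g. `python -m <package>.<this_module>`,
# since the relative imports above fail when run as a bare script).
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load audio as a mono waveform at the model's expected sample rate.
    (audio, _) = librosa.load('example.wav', sr=config.sample_rate, mono=True)

    # Build the transcriptor (downloads the checkpoint on first use) and
    # write the transcription to 'example.mid'.
    transcriptor = PianoTranscription(device=device,
        gui_callback=lambda msg: print('[status]', msg))
    transcribed_dict = transcriptor.transcribe(audio, 'example.mid')
    print('First note events:', transcribed_dict['est_note_events'][:5])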