@@ -55,11 +55,11 @@ def midi_to_piano_roll(midi_path: str, fps: int = 20) -> np.ndarray:
5555
5656
5757class ScoreFollowingEnv (gym .Env ):
58- def __init__ (self , midi_path : str , audio_path : str , bpm : int , alignment : np .ndarray ):
58+ def __init__ (self , midi_path : str , audio_path : str , bpm : int , alignment : np .ndarray , training = False ):
5959 super (ScoreFollowingEnv , self ).__init__ ()
6060
6161 self .alignment = alignment
62-
62+ self . training = training
6363 # Define audio processing parameters
6464 sr = 22050 # Sample rate in Hz
6565 n_fft = 2048 # FFT window size
@@ -90,15 +90,18 @@ def __init__(self, midi_path: str, audio_path: str, bpm: int, alignment: np.ndar
9090
9191 # Define window sizes (in quarter notes)
9292 self .score_window_beats = 10 # Number of beats for score context
93- columns_per_beat = 1 # Number of columns per beat in the piano roll
93+ self .columns_per_beat = 4 # Number of columns per beat in the piano roll
94+ columns_per_beat = self .columns_per_beat
9495 score_fps = calculate_piano_roll_fps (columns_per_beat , bpm ) # Calculate fps based on BPM
9596
9697 # Get the piano roll representation of the MIDI file
9798 # This is the "world" the agent will be navigating
9899 self .piano_roll = midi_to_piano_roll (midi_path , fps = score_fps )
99100 self .size = self .piano_roll .shape [1 ]
100101
101- self .tracking_window = 5 # max distance from target to agent before termination
102+ self .tracking_window = 15 if self .training else 5
103+ self .tracking_window *= columns_per_beat # Extend leniency because we grow note sizes?
104+ # max distance from target to agent before termination
102105
103106 # Define dimensions for our fixed-size representations
104107 # Score window length is a fixed number of beats
@@ -202,7 +205,7 @@ def update_target_location(self):
202205 target_index = np .where (note_onsets > live_time )[0 ]
203206 if target_index .size > 0 : # if there are note onsets after the current time
204207 target_index = target_index [0 ] # get the first one
205- self ._target_location = beats [target_index ] # get the corresponding beat
208+ self ._target_location = beats [target_index ] * self . columns_per_beat # get the corresponding beat
206209 else :
207210 # If no note onsets are found, set target_location to the end of the audio
208211 self ._target_location = beats [- 1 ]
@@ -226,9 +229,20 @@ def _get_obs(self):
226229 }
227230
228231 def _get_info (self ):
229- return {"distance" : abs (self ._agent_location - self ._target_location )}
232+ return {"distance" : abs (self ._agent_location - self ._target_location ), "target" : self . _target_location }
230233
231234 def reset (self , seed = None ):
235+ super ().reset (seed = seed )
236+
237+ # Trying to change starting position during training because otherwise agent never moved.
238+ # if self.training:
239+ # self._agent_location = int(self.np_random.integers(0, self.size))#0
240+ # self._target_location = self._agent_location
241+ # while self._target_location == self._agent_location:
242+ # self._target_location = int(self.np_random.integers(0, self.size))
243+ # self.num_steps = int(self._agent_location)
244+
245+ # else:
232246 self ._agent_location = 0
233247 self ._target_location = 0
234248 self .num_steps = 0
@@ -243,18 +257,25 @@ def step(self, action):
243257 self ._agent_location -= 1
244258 elif action == 1 :
245259 self ._agent_location += 1
246-
260+
247261 # Clip the agent's location to be within the valid range
248262 self ._agent_location = np .clip (self ._agent_location , 0 , self .size - 1 )
249263
250- offtrack = abs (self ._agent_location - self ._target_location ) > self .tracking_window
264+ offtrack = abs (self ._agent_location - self ._target_location ) > self .tracking_window #
251265 end_of_score = self ._agent_location >= self .size
252266 end_of_spectrogram = self .num_steps >= self .spectrogram .shape [1 ]
253- terminated = offtrack or end_of_score or end_of_spectrogram
267+ terminated = end_of_score or end_of_spectrogram or offtrack
254268
255269 truncated = False
256270 tracking_error = self ._agent_location - self ._target_location
271+
257272 reward = 1 - abs (tracking_error ) / self .tracking_window # Compute reward based on tracking error
273+ # reward = np.exp(-0.5 * (tracking_error / self.tracking_window)**2) # Gaussian curve
274+
275+ if action == 2 and tracking_error > 0 :
276+ # reward -= (abs(tracking_error) / self.tracking_window) * 0.5
277+ reward -= 0.5 #try to discourage staying still
278+
258279 self .num_steps += 1 # Increment the number of steps
259280 self .update_target_location ()
260281 observation = self ._get_obs ()
0 commit comments