Skip to content

Commit 23e7a90

Browse files
qi116 (Brian Qi) and a co-author authored
Score Following RL (#195)
* Updated RL requirements.txt. * Tried changing the environment; changes: increased the tracking window to 15 when training, added a -0.5 reward when standing still to encourage movement, and tried randomizing the start location for training. * Changed columns_per_beat to 16, basically manually aligning so that the agent can keep moving forward. * Added VecNormalize support; seems to work well. Changed columns_per_beat to 4. * Included changes for normalization; might need to change tracking_window. --------- Co-authored-by: Brian Qi <brianqi@Brians-MacBook-Pro.local>
1 parent 9499b40 commit 23e7a90

File tree

5 files changed

+64
-18
lines changed

5 files changed

+64
-18
lines changed

reinforcement_learning/gymnasium_env/envs/score_following_env.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,11 @@ def midi_to_piano_roll(midi_path: str, fps: int = 20) -> np.ndarray:
5555

5656

5757
class ScoreFollowingEnv(gym.Env):
58-
def __init__(self, midi_path: str, audio_path: str, bpm: int, alignment: np.ndarray):
58+
def __init__(self, midi_path: str, audio_path: str, bpm: int, alignment: np.ndarray, training=False):
5959
super(ScoreFollowingEnv, self).__init__()
6060

6161
self.alignment = alignment
62-
62+
self.training = training
6363
# Define audio processing parameters
6464
sr = 22050 # Sample rate in Hz
6565
n_fft = 2048 # FFT window size
@@ -90,15 +90,18 @@ def __init__(self, midi_path: str, audio_path: str, bpm: int, alignment: np.ndar
9090

9191
# Define window sizes (in quarter notes)
9292
self.score_window_beats = 10 # Number of beats for score context
93-
columns_per_beat = 1 # Number of columns per beat in the piano roll
93+
self.columns_per_beat = 4 # Number of columns per beat in the piano roll
94+
columns_per_beat = self.columns_per_beat
9495
score_fps = calculate_piano_roll_fps(columns_per_beat, bpm) # Calculate fps based on BPM
9596

9697
# Get the piano roll representation of the MIDI file
9798
# This is the "world" the agent will be navigating
9899
self.piano_roll = midi_to_piano_roll(midi_path, fps=score_fps)
99100
self.size = self.piano_roll.shape[1]
100101

101-
self.tracking_window = 5 # max distance from target to agent before termination
102+
self.tracking_window = 15 if self.training else 5
103+
self.tracking_window *= columns_per_beat # Extend leniency because we grow note sizes?
104+
# max distance from target to agent before termination
102105

103106
# Define dimensions for our fixed-size representations
104107
# Score window length is a fixed number of beats
@@ -202,7 +205,7 @@ def update_target_location(self):
202205
target_index = np.where(note_onsets > live_time)[0]
203206
if target_index.size > 0: # if there are note onsets after the current time
204207
target_index = target_index[0] # get the first one
205-
self._target_location = beats[target_index] # get the corresponding beat
208+
self._target_location = beats[target_index] * self.columns_per_beat # get the corresponding beat
206209
else:
207210
# If no note onsets are found, set target_location to the end of the audio
208211
self._target_location = beats[-1]
@@ -226,9 +229,20 @@ def _get_obs(self):
226229
}
227230

228231
def _get_info(self):
229-
return {"distance": abs(self._agent_location - self._target_location)}
232+
return {"distance": abs(self._agent_location - self._target_location), "target": self._target_location}
230233

231234
def reset(self, seed=None):
235+
super().reset(seed=seed)
236+
237+
# Trying to change starting position during training because otherwise agent never moved.
238+
# if self.training:
239+
# self._agent_location = int(self.np_random.integers(0, self.size))#0
240+
# self._target_location = self._agent_location
241+
# while self._target_location == self._agent_location:
242+
# self._target_location = int(self.np_random.integers(0, self.size))
243+
# self.num_steps = int(self._agent_location)
244+
245+
# else:
232246
self._agent_location = 0
233247
self._target_location = 0
234248
self.num_steps = 0
@@ -243,18 +257,25 @@ def step(self, action):
243257
self._agent_location -= 1
244258
elif action == 1:
245259
self._agent_location += 1
246-
260+
247261
# Clip the agent's location to be within the valid range
248262
self._agent_location = np.clip(self._agent_location, 0, self.size - 1)
249263

250-
offtrack = abs(self._agent_location - self._target_location) > self.tracking_window
264+
offtrack = abs(self._agent_location - self._target_location) > self.tracking_window #
251265
end_of_score = self._agent_location >= self.size
252266
end_of_spectrogram = self.num_steps >= self.spectrogram.shape[1]
253-
terminated = offtrack or end_of_score or end_of_spectrogram
267+
terminated = end_of_score or end_of_spectrogram or offtrack
254268

255269
truncated = False
256270
tracking_error = self._agent_location - self._target_location
271+
257272
reward = 1 - abs(tracking_error) / self.tracking_window # Compute reward based on tracking error
273+
# reward = np.exp(-0.5 * (tracking_error / self.tracking_window)**2) # Gaussian curve
274+
275+
if action == 2 and tracking_error > 0:
276+
# reward -= (abs(tracking_error) / self.tracking_window) * 0.5
277+
reward -= 0.5 #try to discourage staying still
278+
258279
self.num_steps += 1 # Increment the number of steps
259280
self.update_target_location()
260281
observation = self._get_obs()
340 KB
Binary file not shown.

reinforcement_learning/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
gymnasium==1.1.1
22
imageio==2.37.0
33
librosa==0.11.0
4-
matplotlib==3.10.1
5-
numpy==2.2.4
4+
matplotlib==3.9.4
5+
numpy>=1.23
66
pretty_midi==0.2.10
77
pygame==2.1.3
88
stable_baselines3==2.6.0
Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,49 @@
11
from gymnasium_env.envs.score_following_env import ScoreFollowingEnv
22
import numpy as np
33
from stable_baselines3 import PPO
4+
from stable_baselines3.common.env_util import make_vec_env
5+
from stable_baselines3.common.vec_env import VecNormalize
46

57
alignment = [(i, 6 / 7 * i) for i in range(0, 64)]
68
alignment = np.array(alignment).T
79

8-
env = ScoreFollowingEnv(midi_path="ode_beg.mid", audio_path="ode_beg.mp3", bpm=70, alignment=alignment)
9-
model = PPO.load("ppo_score_following", env=env)
10+
env_kwargs={
11+
"midi_path": "ode_beg.mid",
12+
"audio_path": "ode_beg.mp3",
13+
"bpm": 70,
14+
"alignment": alignment,
15+
}
16+
vec_env = make_vec_env(lambda: ScoreFollowingEnv(**env_kwargs), n_envs=1)
17+
env = VecNormalize.load("ppo_score_following_env4", vec_env)
18+
# env = ScoreFollowingEnv(midi_path="ode_beg.mid", audio_path="ode_beg.mp3", bpm=70, alignment=alignment, training=False)
19+
20+
model = PPO.load("ppo_score_following4", env=env)
1021

1122
# Reset the environment
12-
obs, info = env.reset()
23+
obs = env.reset()
1324
terminated = False
1425

1526
# Get the initial agent location
1627
agent_location = obs["agent"][0]
1728

29+
total_reward = 0
1830
i = 0
1931
while not terminated:
2032
# Get the action from the model
2133
action, _ = model.predict(obs, deterministic=True)
2234

2335
# Take a step in the environment
24-
obs, reward, terminated, truncated, info = env.step(action)
36+
obs, reward, terminated, truncated = env.step(action)
2537
agent_location, score_window, spectrogram_window = obs["agent"][0], obs["score"], obs["spectrogram"]
2638
env.render(mode="human")
2739

2840
# Print the agent's location and reward
29-
print(f"Step {i}: Agent location: {agent_location}, Reward: {reward}, Info: {info}")
30-
i += 1
41+
current_agent_loc = obs["agent"][0][0] # Adjust indexing based on actual obs structure
42+
current_reward = reward[0]
43+
target_loc = env.get_attr('_target_location')[0]
44+
agent_loc= env.get_attr('_agent_location')[0]
45+
print(f"Step {i}: Action: {action[0]}, Agent loc: {agent_loc:.2f}, target loc: {target_loc:.2f}, Reward: {current_reward:.3f}")
46+
i += 1
47+
total_reward += reward
48+
49+
print(f'Total reward: {total_reward}')

reinforcement_learning/train_agent.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from stable_baselines3 import PPO
33
from gymnasium_env.envs.score_following_env import ScoreFollowingEnv
44
from stable_baselines3.common.env_util import make_vec_env
5+
from stable_baselines3.common.vec_env import VecNormalize
56
from tqdm import tqdm
67

78

@@ -17,14 +18,19 @@
1718
"audio_path": "ode_beg.mp3",
1819
"bpm": 70,
1920
"alignment": alignment,
21+
"training": True
2022
},
2123
)
2224

25+
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=False, gamma=0.99)
26+
2327
# Create the PPO model using MultiInputPolicy to handle the Dict observation space.
2428
model = PPO("MultiInputPolicy", vec_env, verbose=1)
2529

2630
# Train the model for a specified number of timesteps.
2731
model.learn(total_timesteps=100_000, progress_bar=tqdm)
2832

2933
# Save the trained model.
30-
model.save("ppo_score_following")
34+
model.save("ppo_score_following4")
35+
vec_env.save("ppo_score_following_env4")
36+
vec_env.close()

0 commit comments

Comments (0)