
Commit c8ec0c4

add direct train script for MaxEntropyIRL (#14)
1 parent cb70bf6 commit c8ec0c4

1 file changed: +246 -0 lines changed

@@ -0,0 +1,246 @@
#
# This file is a refactored implementation of the Maximum Entropy IRL from:
# https://github.com/reinforcement-learning-kr/lets-do-irl/tree/master/mountaincar/maxent
# It is a class-based implementation restructured for our use case.
#

import gym
import numpy as np
import matplotlib.pyplot as plt


class MaxEntropyIRL:
    def __init__(self, env, feature_matrix, one_feature, q_table, q_learning_rate, gamma, n_states, theta):
        self.env = env
        self.feature_matrix = feature_matrix
        self.one_feature = one_feature
        self.q_table = q_table
        self.q_learning_rate = q_learning_rate
        self.theta = theta
        self.gamma = gamma
        self.n_states = n_states

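    # The feature matrix holds one feature vector per discretized state; with the identity
    # matrix used in the main block below, each state is represented by a one-hot vector.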
    def get_feature_matrix(self):
        """
        Returns the feature matrix.
        :return: feature matrix with one row per discretized state
        """
        return self.feature_matrix

    def get_reward(self, n_states, state_idx):
        """
        Returns the IRL reward for the given state.
        :param n_states: number of discretized states
        :param state_idx: index of the current state
        :return: reward of the state under the current theta
        """
        # The reward is linear in the state features: r(s) = phi(s) . theta
        irl_rewards = self.feature_matrix.dot(self.theta).reshape((n_states,))
        return irl_rewards[state_idx]

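    # The expert demonstrations are stored as raw (position, velocity, action) samples;
    # get_demonstrations() bins each sample onto the state grid and keeps the resulting
    # state index together with the action.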
    def get_demonstrations(self):
        """
        Loads the expert demonstrations and discretizes them into state indices.
        :return: array of demonstrations holding the state index and action per step
        """
        env_low = self.env.observation_space.low
        env_high = self.env.observation_space.high
        env_distance = (env_high - env_low) / self.one_feature

        raw_demo = np.load(file="src/irlwpython/expert_demo/expert_demo.npy")
        demonstrations = np.zeros((len(raw_demo), len(raw_demo[0]), 3))
        for x in range(len(raw_demo)):
            for y in range(len(raw_demo[0])):
                position_idx = int((raw_demo[x][y][0] - env_low[0]) / env_distance[0])
                velocity_idx = int((raw_demo[x][y][1] - env_low[1]) / env_distance[1])
                state_idx = position_idx + velocity_idx * self.one_feature

                demonstrations[x][y][0] = state_idx
                demonstrations[x][y][1] = raw_demo[x][y][2]

        return demonstrations

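    # The expert feature expectations mu_E are the feature counts of the visited states,
    # averaged over all demonstration trajectories.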
    def expert_feature_expectations(self, demonstrations):
        """
        Returns the expert feature expectations.
        :param demonstrations: discretized expert demonstrations
        :return: mean feature counts over all demonstrations
        """
        feature_expectations = np.zeros(self.feature_matrix.shape[0])

        for demonstration in demonstrations:
            for state_idx, _, _ in demonstration:
                feature_expectations += self.feature_matrix[int(state_idx)]

        feature_expectations /= demonstrations.shape[0]
        return feature_expectations

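    # The continuous observation (position, velocity) is binned onto a one_feature x
    # one_feature grid; the flat index is position_idx + velocity_idx * one_feature.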
    def state_to_idx(self, env, state):
        """
        Converts a state (position, velocity) to an integer index for the mountain car environment.
        :param env: the gym environment
        :param state: observation (position, velocity)
        :return: discretized state index
        """
        env_low = env.observation_space.low
        env_high = env.observation_space.high
        env_distance = (env_high - env_low) / self.one_feature
        position_idx = int((state[0] - env_low[0]) / env_distance[0])
        velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
        state_idx = position_idx + velocity_idx * self.one_feature
        return state_idx

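    # MaxEnt IRL update: the gradient of the maximum-entropy objective with respect to
    # theta is the difference between expert and learner feature expectations, so theta
    # is raised for states the expert visits more often than the learner.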
    def maxent_irl(self, expert, learner, learning_rate):
        """
        Maximum Entropy learning step.
        :param expert: expert feature expectations
        :param learner: learner feature expectations
        :param learning_rate: learning rate for theta
        """
        gradient = expert - learner
        self.theta += learning_rate * gradient

        # Clip theta so that all entries stay non-positive
        for j in range(len(self.theta)):
            if self.theta[j] > 0:
                self.theta[j] = 0

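    # Tabular Q-learning update:
    # Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))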
    def update_q_table(self, state, action, reward, next_state):
        """
        Updates the Q-table for a specified state and action.
        :param state: current state index
        :param action: chosen action
        :param reward: observed (pseudo-)reward
        :param next_state: next state index
        """
        q_1 = self.q_table[state][action]
        q_2 = reward + self.gamma * max(self.q_table[next_state])
        self.q_table[state][action] += self.q_learning_rate * (q_2 - q_1)


# Training loop
def train(agent, env, theta_learning_rate, episode_count=30000):
    demonstrations = agent.get_demonstrations()
    expert = agent.expert_feature_expectations(demonstrations)
    learner_feature_expectations = np.zeros(agent.n_states)

    episodes, scores = [], []
    for episode in range(episode_count):
        state, info = env.reset()
        score = 0

        # Mini-batches: every 10 episodes, update theta from the accumulated
        # learner state-visitation counts.
        if (episode + 1) % 10 == 0:
            # Calculate the learner's state-visitation density
            learner = learner_feature_expectations / episode
            learner_feature_expectations = np.zeros(agent.n_states)

            agent.maxent_irl(expert, learner, theta_learning_rate)

        while True:
            state_idx = agent.state_to_idx(env, state)
            action = np.argmax(agent.q_table[state_idx])

            # Run one timestep of the environment's dynamics.
            next_state, reward, done, truncated, _ = env.step(action)

            # Get the pseudo-reward and update the Q-table
            irl_reward = agent.get_reward(agent.n_states, state_idx)
            next_state_idx = agent.state_to_idx(env, next_state)
            agent.update_q_table(state_idx, action, irl_reward, next_state_idx)

            # State counting for the density estimate
            learner_feature_expectations += agent.feature_matrix[int(state_idx)]

            score += reward
            state = next_state
            if done or truncated:
                scores.append(score)
                episodes.append(episode)
                break

        if (episode + 1) % 1000 == 0:
            score_avg = np.mean(scores)
            print('{} episode score is {:.2f}'.format(episode, score_avg))
            save_plot_as_png(episodes, scores,
                             f"src/irlwpython/learning_curves/maxent_{episode_count}_{episode}_qtable.png")
            save_heatmap_as_png(learner.reshape((20, 20)),
                                f"src/irlwpython/heatmap/learner_{episode}_flat.png")
            save_heatmap_as_png(agent.theta.reshape((20, 20)),
                                f"src/irlwpython/heatmap/theta_{episode}_flat.png")

            np.save(f"src/irlwpython/results/maxent_{episode}_qtable", arr=agent.q_table)

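# Note: the learning_curves, heatmap and results directories referenced above are
# expected to exist already; matplotlib and numpy will not create missing directories.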

def save_heatmap_as_png(data, output_path, title=None, xlabel="Position", ylabel="Velocity"):
    """
    Create a heatmap from a numpy array and save it as a PNG file.
    :param data: 2D numpy array containing the heatmap data.
    :param output_path: Output path for saving the PNG file.
    :param xlabel: Label for the x-axis (optional).
    :param ylabel: Label for the y-axis (optional).
    :param title: Title for the plot (optional).
    """
    fig, ax = plt.subplots()
    im = ax.imshow(data, cmap='viridis', interpolation='nearest')
    plt.colorbar(im)

    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    if title:
        plt.title(title)

    plt.savefig(output_path, format='png')
    plt.close(fig)


def save_plot_as_png(x, y, output_path, title=None, xlabel="Episodes", ylabel="Scores"):
    """
    Create a line plot from x and y data and save it as a PNG file.
    :param x: 1D numpy array or list representing the x-axis values.
    :param y: 1D numpy array or list representing the y-axis values.
    :param output_path: Output path for saving the plot as a PNG file.
    :param xlabel: Label for the x-axis (optional).
    :param ylabel: Label for the y-axis (optional).
    :param title: Title for the plot (optional).
    """
    fig, ax = plt.subplots()
    ax.plot(x, y)

    if xlabel:
        plt.xlabel(xlabel)
    if ylabel:
        plt.ylabel(ylabel)
    if title:
        plt.title(title)

    plt.savefig(output_path, format='png')
    plt.close(fig)


# Main function
if __name__ == "__main__":
    n_states = 400  # position - 20, velocity - 20 -> 20 * 20
    n_actions = 3  # Accelerate to the left: 0, Don’t accelerate: 1, Accelerate to the right: 2
    state_dim = 2  # Velocity and position
    one_feature = 20
    feature_matrix = np.eye(n_states)

    gamma = 0.99
    q_learning_rate = 0.03

    # Theta works as rewards
    theta_learning_rate = 0.001
    theta = -(np.random.uniform(size=(n_states,)))

    env = gym.make('MountainCar-v0')

    q_table = np.zeros((n_states, n_actions))
    agent = MaxEntropyIRL(env, feature_matrix, one_feature, q_table, q_learning_rate, gamma, n_states, theta)

    train(agent, env, theta_learning_rate)
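
    # train() runs for 30000 episodes by default; pass a smaller episode_count for a
    # quick test run, e.g. train(agent, env, theta_learning_rate, episode_count=1000)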
