
Commit 67539d1

authored
feat: add DiscreteMaxEntropyDeepIRL (#7)
* add DiscreteMaxEntropyDeepIRL
* update main
1 parent 9805a6e commit 67539d1


5 files changed: +198, -24 lines

README.md

Lines changed: 4 additions & 2 deletions
@@ -28,15 +28,17 @@ pip install .
 # Usage
 
 ```commandline
-usage: irl [-h] [--version] [--training] [--testing] [--render]
+usage: irl [-h] [--version] [--training] [--testing] [--render] ALGORITHM
 
 Implementation of IRL algorithms
 
+positional arguments:
+  ALGORITHM   Currently supported training algorithm: [max-entropy, discrete-max-entropy-deep]
+
 options:
   -h, --help  show this help message and exit
   --version   show program's version number and exit
   --training  Enables training of model.
   --testing   Enables testing of previously created model.
   --render    Enables visualization of mountaincar.
-
 ```
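Based on the updated usage string above, a run can be invoked roughly as follows. This is a minimal sketch assuming the `irl` entry point from the usage line is installed; flag spelling follows the options list:

```commandline
# Train the new discrete deep variant with rendering enabled
irl --training --render discrete-max-entropy-deep

# Test a previously trained tabular max-entropy model
irl --testing max-entropy
```

Note that main.py in this commit leaves the testing branch for discrete-max-entropy-deep empty, so only max-entropy is testable for now.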

src/irlwpython/MaxEntropyDeepIRL.py renamed to src/irlwpython/ContinuousMaxEntropyDeepIRL.py

Lines changed: 8 additions & 0 deletions
@@ -107,12 +107,20 @@ def get_demonstrations(self):
                 demonstrations[x][y][0] = state_idx
                 demonstrations[x][y][1] = raw_demo[x][y][2]
 
+        print(demonstrations)
         return demonstrations
 
+    def get_expert_state_frequencies(self):
+        raw_demo = np.load(file="expert_demo/expert_demo.npy")
+        expert_state_frequencies = []
+        return expert_state_frequencies
+
     def train(self):
        demonstrations = self.get_demonstrations()
        expert = self.expert_feature_expectations(demonstrations)
 
+        expert_state_frequencies = self.get_expert_state_frequencies()
+
         learner_feature_expectations = torch.zeros(self.state_dim, requires_grad=True)  # Add requires_grad=True
         episodes, scores = [], []
src/irlwpython/DiscreteMaxEntropyDeepIRL.py

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
+import gym
+import numpy as np
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import matplotlib.pyplot as plt
+
+
+class ActorNetwork(nn.Module):
+    def __init__(self, num_inputs, num_output, hidden_size):
+        super(ActorNetwork, self).__init__()
+        self.fc1 = nn.Linear(num_inputs, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, hidden_size)
+        self.fc3 = nn.Linear(hidden_size, num_output)
+
+    def forward(self, x):
+        x = nn.functional.relu(self.fc1(x))
+        x = nn.functional.relu(self.fc2(x))
+        return self.fc3(x)  # torch.nn.functional.softmax(self.fc3(x))
+
+
+class CriticNetwork(nn.Module):
+    def __init__(self, num_inputs, hidden_size):
+        super(CriticNetwork, self).__init__()
+        self.fc1 = nn.Linear(num_inputs, hidden_size)
+        self.fc2 = nn.Linear(hidden_size, hidden_size)
+        self.fc3 = nn.Linear(hidden_size, 1)
+
+        self.theta_layer = nn.Linear(hidden_size, 3)
+
+    def forward(self, x):
+        x_ = nn.functional.relu(self.fc1(x))
+        x_ = nn.functional.relu(self.fc2(x_))
+        theta_ = self.theta_layer(x_)
+        return self.fc3(x_) + torch.matmul(theta_, x)  # value head plus linear theta term
+
+
+class DiscreteMaxEntropyDeepIRL:
+    def __init__(self, target, state_dim, action_dim, feature_matrix=None, learning_rate=0.001, gamma=0.99,
+                 num_epochs=1000):
+        self.feat_matrix = feature_matrix
+        self.one_feature = 20
+
+        self.target = target
+        self.state_dim = state_dim
+        self.action_dim = action_dim
+        self.learning_rate = learning_rate
+
+        self.gamma = gamma
+        self.num_epochs = num_epochs
+        self.actor_network = ActorNetwork(state_dim, action_dim, 100)
+        self.critic_network = CriticNetwork(state_dim + 1, 100)
+        self.optimizer_actor = optim.Adam(self.actor_network.parameters(), lr=learning_rate)
+        self.optimizer_critic = optim.Adam(self.critic_network.parameters(), lr=learning_rate)
+
+    def get_reward(self, state, action):
+        state_action = list(state) + list([action])
+        state_action = torch.Tensor(state_action)
+        return self.critic_network(state_action)
+
+    def expert_feature_expectations(self, demonstrations):
+        feature_expectations = torch.zeros(400)
+
+        for demonstration in demonstrations:
+            for state, _, _ in demonstration:
+                state_tensor = torch.tensor(state, dtype=torch.float32)
+                feature_expectations += state_tensor.squeeze()
+
+        feature_expectations /= demonstrations.shape[0]
+        return feature_expectations
+
+    def maxent_irl(self, expert, learner):
+        # Update critic network
+
+        self.optimizer_critic.zero_grad()
+
+        # Loss function for critic network
+        loss_critic = torch.nn.functional.mse_loss(learner, expert)
+        loss_critic.backward()
+
+        self.optimizer_critic.step()
+
+    def update_q_network(self, state_array, action, reward, next_state):
+        self.optimizer_actor.zero_grad()
+
+        state_tensor = torch.tensor(state_array, dtype=torch.float32)
+        next_state_tensor = torch.tensor(next_state, dtype=torch.float32)
+
+        q_values = self.actor_network(state_tensor)
+        q_1 = self.actor_network(state_tensor)[action]
+
+        q_2 = reward + self.gamma * max(self.actor_network(next_state_tensor))
+        next_q_values = reward + self.gamma * (q_2 - q_1)  # self.actor_network(next_state_tensor)
+
+        loss_actor = nn.functional.mse_loss(q_values, next_q_values)
+        loss_actor.backward()
+        self.optimizer_actor.step()
+
+    def train(self):
+        demonstrations = self.target.get_demonstrations()
+        expert = self.expert_feature_expectations(demonstrations)
+
+        learner_feature_expectations = torch.zeros(400, requires_grad=True)
+        episodes, scores = [], []
+
+        for episode in range(self.num_epochs):
+            state, info = self.target.env_reset()
+            score = 0
+
+            while True:
+                state_tensor = torch.tensor(state, dtype=torch.float32)
+
+                q_state = self.actor_network(state_tensor)
+                action = torch.argmax(q_state).item()
+                next_state, reward, done, _, _ = self.target.env_step(action)
+
+                # Actor update
+                irl_reward = self.get_reward(state, action)
+                self.update_q_network(state, action, irl_reward, next_state)
+
+                score += reward
+                state = next_state
+                if done:
+                    scores.append(score)
+                    episodes.append(episode)
+                    break
+
+            # Critic update
+            state_idx = state[0] + state[1] * self.one_feature
+            learner_feature_expectations = learner_feature_expectations + torch.Tensor(
+                self.feat_matrix[int(state_idx)])
+            learner = learner_feature_expectations / (episode + 1)  # average over completed episodes
+            self.maxent_irl(expert, learner)
+
+            if episode % 1 == 0:
+                score_avg = np.mean(scores)
+                print('{} episode score is {:.2f}'.format(episode, score_avg))
+                plt.plot(episodes, scores, 'b')
+                plt.savefig("./learning_curves/maxent_30000_network.png")
+
+        torch.save(self.actor_network.state_dict(), "./results/maxent_30000_q_network.pth")  # actor holds the Q-network
+
+    def test(self):
+        episodes, scores = [], []
+
+        for episode in range(10):
+            state, info = self.target.env_reset()
+            score = 0
+
+            while True:
+                self.target.env_render()
+                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
+
+                action = torch.argmax(self.actor_network(state_tensor)).item()
+                next_state, reward, done, _, _ = self.target.env_step(action)
+
+                score += reward
+                state = next_state
+
+                if done:
+                    scores.append(score)
+                    episodes.append(episode)
+                    plt.plot(episodes, scores, 'b')
+                    plt.savefig("./learning_curves/maxent_test_30000_network.png")
+                    break
+
+            if episode % 1 == 0:
+                print('{} episode score is {:.2f}'.format(episode, score))
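Taken together with the main.py changes below, a minimal way to drive this class looks roughly like the following sketch. The `MountainCar` wrapper call and the constructor arguments mirror main.py; the 20x20 discretization comes from `self.one_feature`, while the identity feature matrix and the 3-action space are assumptions borrowed from the existing tabular setup rather than from this file:

```python
import numpy as np
from irlwpython.MountainCar import MountainCar
from irlwpython.DiscreteMaxEntropyDeepIRL import DiscreteMaxEntropyDeepIRL

one_feature = 20                     # bins per state dimension (matches self.one_feature)
n_states = one_feature ** 2          # 400 discretized states
feature_matrix = np.eye(n_states)    # assumption: identity features, as in the tabular IRL
n_actions = 3                        # assumption: MountainCar's discrete action space
state_dim = 2                        # (position, velocity), as in main.py

car = MountainCar(False, one_feature)
agent = DiscreteMaxEntropyDeepIRL(car, state_dim, n_actions, feature_matrix)
agent.train()
```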

src/irlwpython/MaxEntropyIRL.py

Lines changed: 1 addition & 3 deletions
@@ -123,9 +123,7 @@ def train(self, theta_learning_rate):
                 self.update_q_table(state_idx, action, irl_reward, next_state_idx)
 
                 # State counting for density
-                learner_feature_expectations += self.get_feature_matrix()[int(state_idx)]
-
-                print(reward, irl_reward)
+                learner_feature_expectations += self.feature_matrix[int(state_idx)]
 
                 score += reward
                 state = next_state

src/irlwpython/main.py

Lines changed: 17 additions & 19 deletions
@@ -3,13 +3,11 @@
 import numpy as np
 import sys
 
-from MountainCar import MountainCar
-from MaxEntropyIRL import MaxEntropyIRL
-from MaxEntropyDeepIRL import MaxEntropyDeepIRL
+from irlwpython.MountainCar import MountainCar
+from irlwpython.MaxEntropyIRL import MaxEntropyIRL
+from irlwpython.DiscreteMaxEntropyDeepIRL import DiscreteMaxEntropyDeepIRL
 
-#from irlwpython import __version__
-
-import gym
+from irlwpython import __version__
 
 __author__ = "HokageM"
 __copyright__ = "HokageM"
@@ -34,9 +32,10 @@ def parse_args(args):
     parser.add_argument(
         "--version",
         action="version",
-        # version=f"IRLwPython {__version__}",
+        version=f"IRLwPython {__version__}",
     )
-    parser.add_argument('--deep', action='store_true', help="Uses Max Entropy Deep IRL.")
+    parser.add_argument('algorithm', metavar='ALGORITHM', type=str,
+                        help='Currently supported training algorithm: [max-entropy, discrete-max-entropy-deep]')
     parser.add_argument('--training', action='store_true', help="Enables training of model.")
     parser.add_argument('--testing', action='store_true',
                         help="Enables testing of previously created model.")
@@ -86,25 +85,24 @@ def main(args):
     else:
         car = MountainCar(False, one_feature)
 
-    if args.deep:
-
-        # Create MountainCar environment
-        env = gym.make('MountainCar-v0', render_mode="human")
-        state_dim = env.observation_space.shape[0]
-        action_dim = env.action_space.n
+    if args.algorithm == "discrete-max-entropy-deep" and args.training:
+        state_dim = 2
 
         # Run MaxEnt Deep IRL using MountainCar environment
-        maxent_deep_irl_agent = MaxEntropyDeepIRL(env, state_dim, action_dim)
+        maxent_deep_irl_agent = DiscreteMaxEntropyDeepIRL(car, state_dim, n_actions, feature_matrix)
         maxent_deep_irl_agent.train()
-        maxent_deep_irl_agent.test()
+        # maxent_deep_irl_agent.test()
+
+    if args.algorithm == "discrete-max-entropy-deep" and args.testing:
+        pass
 
-    if args.training:
+    if args.algorithm == "max-entropy" and args.training:
         q_table = np.zeros((n_states, n_actions))
         trainer = MaxEntropyIRL(car, feature_matrix, one_feature, q_table, q_learning_rate, gamma, n_states, theta)
         trainer.train(theta_learning_rate)
 
-    if args.testing:
-        q_table = np.load(file="./results/maxent_q_table.npy")  # (400, 3)
+    if args.algorithm == "max-entropy" and args.testing:
+        q_table = np.load(file="./results/maxent_q_table.npy")
         trainer = MaxEntropyIRL(car, feature_matrix, one_feature, q_table, q_learning_rate, gamma, n_states, theta)
         trainer.test()