
Commit 2e74715

Merge pull request #1 from HokageM/feat/add_arguments
feat: add arguments for testing, training and render
2 parents 58dd9f9 + ee3f1dc commit 2e74715

File tree

10 files changed: +139 -100 lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@ __pycache__/
 # C extensions
 *.so
 
+.idea/
+
 # Distribution / packaging
 .Python
 build/

README.md

Lines changed: 24 additions & 1 deletion
@@ -3,4 +3,27 @@ Inverse Reinforcement Learning Algorithm implementation with Pytorch.
 
 The implementation is based on: https://github.com/reinforcement-learning-kr/lets-do-irl
 
-Mountaincar experiment from: https://www.gymlibrary.dev/environments/classic_control/mountain_car/
+Mountaincar experiment from: https://www.gymlibrary.dev/environments/classic_control/mountain_car/
+
+# Installation
+
+```commandline
+cd IRLwPytorch
+pip install .
+```
+
+# Usage
+
+```commandline
+usage: irl [-h] [--version] [--training] [--testing] [--render]
+
+Implementation of IRL algorithms
+
+options:
+  -h, --help  show this help message and exit
+  --version   show program's version number and exit
+  --training  Enables training of model.
+  --testing   Enables testing of previously created model.
+  --render    Enables visualization of mountaincar.
+
+```
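The new flags can also be combined in a single call; for example, training with the MountainCar window rendered would presumably be invoked as follows (a hypothetical invocation assembled from the options documented above):

```commandline
irl --training --render
```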

setup.cfg

Lines changed: 2 additions & 2 deletions
@@ -76,8 +76,8 @@ testing =
 # console_scripts =
 #     script_name = irlwpytorch.module:function
 # For example:
-# console_scripts =
-#     fibonacci = irlwpytorch.skeleton:run
+console_scripts =
+    irl = irlwpytorch.main:run
 # And any other entry points, for example:
 # pyscaffold.cli =
 #     awesome = pyscaffoldext.awesome.extension:AwesomeExtension
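This change activates the PyScaffold entry-point stub: after `pip install .`, an `irl` command is generated that calls `run()` in `irlwpytorch/main.py`. The body of `run()` is not shown in this diff; by the usual PyScaffold convention it is a thin wrapper along these lines (a sketch, assuming the default skeleton):

```python
import sys


def run():
    """Console-script entry point: forwards the process arguments to main()."""
    main(sys.argv[1:])
```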

src/irlwpytorch/MountainCar.py

Lines changed: 68 additions & 2 deletions
@@ -1,10 +1,76 @@
+import gym
+import numpy as np
+
 class MountainCar:
 
-    def __init__(self):
-        pass
+    def __init__(self, animation, feature_matrix, one_feature, q_learning_rate, gamma):
+        if animation:
+            self.env = gym.make('MountainCar-v0', render_mode="human")
+        else:
+            self.env = gym.make('MountainCar-v0')
+        self.feature_matrix = feature_matrix
+        self.one_feature = one_feature
+        self.q_table = None
+        self.q_learning_rate = q_learning_rate
+        self.gamma = gamma
 
     def __enter__(self):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         pass
+
+    def set_q_table(self, table):
+        self.q_table = table
+
+    def idx_demo(self, one_feature):
+        env_low = self.env.observation_space.low
+        env_high = self.env.observation_space.high
+        env_distance = (env_high - env_low) / self.one_feature
+
+        raw_demo = np.load(file="expert_demo/expert_demo.npy")
+        demonstrations = np.zeros((len(raw_demo), len(raw_demo[0]), 3))
+
+        for x in range(len(raw_demo)):
+            for y in range(len(raw_demo[0])):
+                position_idx = int((raw_demo[x][y][0] - env_low[0]) / env_distance[0])
+                velocity_idx = int((raw_demo[x][y][1] - env_low[1]) / env_distance[1])
+                state_idx = position_idx + velocity_idx * one_feature
+
+                demonstrations[x][y][0] = state_idx
+                demonstrations[x][y][1] = raw_demo[x][y][2]
+
+        return demonstrations
+
+    def idx_state(self, state):
+        env_low = self.env.observation_space.low
+        env_high = self.env.observation_space.high
+        env_distance = (env_high - env_low) / self.one_feature
+        position_idx = int((state[0] - env_low[0]) / env_distance[0])
+        velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
+        state_idx = position_idx + velocity_idx * self.one_feature
+        return state_idx
+
+    def idx_to_state(self, state):
+        """Convert the mountain car's position and velocity into an integer state index."""
+        env_low = self.env.observation_space.low
+        env_high = self.env.observation_space.high
+        env_distance = (env_high - env_low) / self.one_feature
+        position_idx = int((state[0] - env_low[0]) / env_distance[0])
+        velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
+        state_idx = position_idx + velocity_idx * self.one_feature
+        return state_idx
+
+    def update_q_table(self, state, action, reward, next_state):
+        q_1 = self.q_table[state][action]
+        q_2 = reward + self.gamma * max(self.q_table[next_state])
+        self.q_table[state][action] += self.q_learning_rate * (q_2 - q_1)
+
+    def env_render(self):
+        self.env.render()
+
+    def env_reset(self):
+        return self.env.reset()
+
+    def env_step(self, action):
+        return self.env.step(action)
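All of the idx_* helpers share one discretization scheme: position and velocity are each cut into one_feature = 20 equal bins over the observation-space bounds, and the two bin indices are flattened into a single integer in [0, 400). A self-contained numeric check (the bounds are the standard MountainCar-v0 values, hard-coded here as an assumption rather than read from the env):

```python
import numpy as np

one_feature = 20                     # bins per state dimension
env_low = np.array([-1.2, -0.07])    # MountainCar-v0 observation lows
env_high = np.array([0.6, 0.07])     # MountainCar-v0 observation highs
env_distance = (env_high - env_low) / one_feature

state = np.array([-0.5, 0.0])        # example (position, velocity)
position_idx = int((state[0] - env_low[0]) / env_distance[0])  # 7
velocity_idx = int((state[1] - env_low[1]) / env_distance[1])  # 10
state_idx = position_idx + velocity_idx * one_feature          # 207
```

Note that idx_state and idx_to_state compute exactly this same mapping.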
(3 image files changed: -7.97 KB, 1.66 KB, 19.8 KB; previews not shown)

src/irlwpytorch/main.py

Lines changed: 43 additions & 95 deletions
@@ -23,13 +23,12 @@
 import argparse
 import gym
 import matplotlib.pyplot as plt
-import numpy as np
 import logging
 import numpy as np
 import sys
 
-from MountainCar import MountainCar
-from MaxEntropyIRL import MaxEntropyIRL
+from .MountainCar import MountainCar
+from .MaxEntropyIRL import MaxEntropyIRL
 
 # from irlwpytorch import __version__
 
@@ -39,70 +38,9 @@
 
 _logger = logging.getLogger(__name__)
 
-n_states = 400  # position - 20, velocity - 20
-n_actions = 3
-one_feature = 20  # number of state per one feature
-q_table = np.zeros((n_states, n_actions))  # (400, 3)
-feature_matrix = np.eye((n_states))  # (400, 400)
-
-gamma = 0.99
-q_learning_rate = 0.03
-theta_learning_rate = 0.05
-
 np.random.seed(1)
 
 
-def idx_demo(env, one_feature):
-    env_low = env.observation_space.low
-    env_high = env.observation_space.high
-    env_distance = (env_high - env_low) / one_feature
-
-    raw_demo = np.load(file="expert_demo/expert_demo.npy")
-    demonstrations = np.zeros((len(raw_demo), len(raw_demo[0]), 3))
-
-    for x in range(len(raw_demo)):
-        for y in range(len(raw_demo[0])):
-            position_idx = int((raw_demo[x][y][0] - env_low[0]) / env_distance[0])
-            velocity_idx = int((raw_demo[x][y][1] - env_low[1]) / env_distance[1])
-            state_idx = position_idx + velocity_idx * one_feature
-
-            demonstrations[x][y][0] = state_idx
-            demonstrations[x][y][1] = raw_demo[x][y][2]
-
-    return demonstrations
-
-
-def idx_state(env, state):
-    env_low = env.observation_space.low
-    env_high = env.observation_space.high
-    env_distance = (env_high - env_low) / one_feature
-    position_idx = int((state[0] - env_low[0]) / env_distance[0])
-    velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
-    state_idx = position_idx + velocity_idx * one_feature
-    return state_idx
-
-
-def update_q_table(state, action, reward, next_state):
-    q_1 = q_table[state][action]
-    q_2 = reward + gamma * max(q_table[next_state])
-    q_table[state][action] += q_learning_rate * (q_2 - q_1)
-
-
-q_table = np.load(file="results/maxent_q_table.npy")  # (400, 3)
-one_feature = 20  # number of state per one feature
-
-
-def idx_to_state(env, state):
-    """ Convert pos and vel about mounting car environment to the integer value"""
-    env_low = env.observation_space.low
-    env_high = env.observation_space.high
-    env_distance = (env_high - env_low) / one_feature
-    position_idx = int((state[0] - env_low[0]) / env_distance[0])
-    velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
-    state_idx = position_idx + velocity_idx * one_feature
-    return state_idx
-
-
 def parse_args(args):
     """Parse command line parameters
@@ -119,6 +57,10 @@ def parse_args(args):
         action="version",
         # version=f"IRLwPytorch {__version__}",
     )
+    parser.add_argument('--training', action='store_true', help="Enables training of model.")
+    parser.add_argument('--testing', action='store_true',
+                        help="Enables testing of previously created model.")
+    parser.add_argument('--render', action='store_true', help="Enables visualization of mountaincar.")
     return parser.parse_args(args)
 
 
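Because the three new options use action='store_true', each defaults to False and flips to True only when passed. A quick sketch of the resulting namespace (a hypothetical check against the parse_args defined above):

```python
args = parse_args(["--training", "--render"])
assert args.training is True
assert args.render is True
assert args.testing is False
```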
@@ -147,36 +89,51 @@ def main(args):
     args = parse_args(args)
     _logger.debug("Starting crazy calculations...")
 
-    car = MountainCar()
+    n_states = 400  # position - 20, velocity - 20
+    n_actions = 3
+    one_feature = 20  # number of state per one feature
+    feature_matrix = np.eye((n_states))  # (400, 400)
+
+    gamma = 0.99
+    q_learning_rate = 0.03
+    theta_learning_rate = 0.05
+
+    car = None
+    if args.render:
+        car = MountainCar(True, feature_matrix, one_feature, q_learning_rate, gamma)
+    else:
+        car = MountainCar(False, feature_matrix, one_feature, q_learning_rate, gamma)
 
     theta = -(np.random.uniform(size=(n_states,)))
     trainer = MaxEntropyIRL(feature_matrix, theta)
 
-    if False:
-        env = gym.make('MountainCar-v0', render_mode="human")
-        demonstrations = idx_demo(env, one_feature)
+    if args.training:
+        q_table = np.zeros((n_states, n_actions))  # (400, 3)
+        car.set_q_table(q_table)
+
+        demonstrations = car.idx_demo(one_feature)
 
         expert = trainer.expert_feature_expectations(demonstrations)
         learner_feature_expectations = np.zeros(n_states)
         episodes, scores = [], []
 
-        for episode in range(300):
-            state = env.reset()
+        for episode in range(30000):
+            state = car.env_reset()
             score = 0
 
-            if (episode != 0 and episode == 100) or (episode > 100 and episode % 50 == 0):
+            if (episode != 0 and episode == 10000) or (episode > 10000 and episode % 5000 == 0):
                 learner = learner_feature_expectations / episode
                 trainer.maxent_irl(expert, learner, theta_learning_rate)
 
             state = state[0]
             while True:
-                state_idx = idx_state(env, state)
+                state_idx = car.idx_state(state)
                 action = np.argmax(q_table[state_idx])
-                next_state, reward, done, _, _ = env.step(action)
+                next_state, reward, done, _, _ = car.env_step(action)
 
                 irl_reward = trainer.get_reward(n_states, state_idx)
-                next_state_idx = idx_state(env, next_state)
-                update_q_table(state_idx, action, irl_reward, next_state_idx)
+                next_state_idx = car.idx_state(next_state)
+                car.update_q_table(state_idx, action, irl_reward, next_state_idx)
 
                 learner_feature_expectations += trainer.get_feature_matrix()[int(state_idx)]
 
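In this branch the learner's feature expectations are accumulated online, averaged per episode, and compared against the expert's after a 10000-episode warm-up and then every 5000 episodes. MaxEntropyIRL itself is not part of this diff; in the referenced lets-do-irl code the update is essentially the plain maximum-entropy gradient step, roughly (a sketch assuming that implementation, not shown here):

```python
import numpy as np

def maxent_irl(theta, expert, learner, learning_rate):
    # Max-entropy IRL gradient: expert feature expectations minus learner's.
    gradient = expert - learner
    theta += learning_rate * gradient
    # lets-do-irl additionally clamps theta so the learned rewards stay <= 0.
    np.minimum(theta, 0.0, out=theta)
    return theta
```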
@@ -187,28 +144,29 @@ def main(args):
                     episodes.append(episode)
                     break
 
-            if episode % 10 == 0:
+            if episode % 100 == 0:
                 score_avg = np.mean(scores)
                 print('{} episode score is {:.2f}'.format(episode, score_avg))
                 plt.plot(episodes, scores, 'b')
-                plt.savefig("./learning_curves/maxent_300.png")
-        np.save("./results/maxent_300_table", arr=q_table)
+                plt.savefig("./learning_curves/maxent_30000.png")
+        np.save("./results/maxent_30000_table", arr=q_table)
 
-    else:
-        env = gym.make('MountainCar-v0', render_mode="human")
+    if args.testing:
+        q_table = np.load(file="results/maxent_q_table.npy")  # (400, 3)
+        car.set_q_table(q_table)
 
         episodes, scores = [], []
 
         for episode in range(10):
-            state = env.reset()
+            state = car.env_reset()
             score = 0
 
             state = state[0]
             while True:
-                env.render()
-                state_idx = idx_to_state(env, state)
+                car.env_render()
+                state_idx = car.idx_to_state(state)
                 action = np.argmax(q_table[state_idx])
-                next_state, reward, done, _, _ = env.step(action)
+                next_state, reward, done, _, _ = car.env_step(action)
 
                 score += reward
                 state = next_state
@@ -217,7 +175,7 @@ def main(args):
                     scores.append(score)
                     episodes.append(episode)
                     plt.plot(episodes, scores, 'b')
-                    plt.savefig("./learning_curves/maxent_test_300.png")
+                    plt.savefig("./learning_curves/maxent_test_30000.png")
                     break
 
             if episode % 1 == 0:
@@ -235,14 +193,4 @@ def run():
 
 
 if __name__ == "__main__":
-    # ^  This is a guard statement that will prevent the following code from
-    #    being executed in the case someone imports this file instead of
-    #    executing it as a script.
-    #    https://docs.python.org/3/library/__main__.html
-
-    # After installing your project with pip, users can also run your Python
-    # modules as scripts via the ``-m`` flag, as defined in PEP 338::
-    #
-    #     python -m irlwpytorch.skeleton 42
-    #
     run()
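Both branches pick actions greedily from q_table, and the training branch updates it through car.update_q_table, the standard tabular Q-learning rule Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a)). A numeric trace with the constants from the diff (the state indices and reward are made up for illustration):

```python
import numpy as np

q_learning_rate, gamma = 0.03, 0.99
q_table = np.zeros((400, 3))
state_idx, action, next_state_idx = 207, 2, 208  # hypothetical transition
irl_reward = -0.5                                # hypothetical IRL reward

q_1 = q_table[state_idx][action]                          # 0.0
q_2 = irl_reward + gamma * max(q_table[next_state_idx])   # -0.5
q_table[state_idx][action] += q_learning_rate * (q_2 - q_1)
print(q_table[state_idx][action])                         # -0.015
```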
(2 binary files changed: 9.5 KB and 0 Bytes; contents not shown)
