
Commit 3529a8e

Merge pull request #2 from HokageM/feat/add_arguments
refactor: code structure
2 parents 2e74715 + ac0916f commit 3529a8e

5 files changed: +87 −115 lines

src/irlwpytorch/MountainCar.py

Lines changed: 73 additions & 10 deletions
@@ -1,9 +1,11 @@
 import gym
 import numpy as np
+import matplotlib.pyplot as plt
+
 
 class MountainCar:
 
-    def __init__(self, animation, feature_matrix, one_feature, q_learning_rate, gamma):
+    def __init__(self, animation, feature_matrix, one_feature, q_learning_rate, gamma, n_states, trainer):
         if animation:
             self.env = gym.make('MountainCar-v0', render_mode="human")
         else:
@@ -13,6 +15,8 @@ def __init__(self, animation, feature_matrix, one_feature, q_learning_rate, gamma):
         self.q_table = None
         self.q_learning_rate = q_learning_rate
         self.gamma = gamma
+        self.n_states = n_states
+        self.trainer = trainer
 
     def __enter__(self):
         return self
@@ -42,15 +46,6 @@ def idx_demo(self, one_feature):
 
         return demonstrations
 
-    def idx_state(self, state):
-        env_low = self.env.observation_space.low
-        env_high = self.env.observation_space.high
-        env_distance = (env_high - env_low) / self.one_feature
-        position_idx = int((state[0] - env_low[0]) / env_distance[0])
-        velocity_idx = int((state[1] - env_low[1]) / env_distance[1])
-        state_idx = position_idx + velocity_idx * self.one_feature
-        return state_idx
-
     def idx_to_state(self, state):
         """ Convert pos and vel about mounting car environment to the integer value"""
         env_low = self.env.observation_space.low
@@ -74,3 +69,71 @@ def env_reset(self):
 
     def env_step(self, action):
         return self.env.step(action)
+
+    def train(self, theta_learning_rate):
+        demonstrations = self.idx_demo(self.one_feature)
+
+        expert = self.trainer.expert_feature_expectations(demonstrations)
+        learner_feature_expectations = np.zeros(self.n_states)
+        episodes, scores = [], []
+
+        for episode in range(30000):
+            state = self.env_reset()
+            score = 0
+
+            if (episode != 0 and episode == 10000) or (episode > 10000 and episode % 5000 == 0):
+                learner = learner_feature_expectations / episode
+                self.trainer.maxent_irl(expert, learner, theta_learning_rate)
+
+            state = state[0]
+            while True:
+                state_idx = self.idx_to_state(state)
+                action = np.argmax(self.q_table[state_idx])
+                next_state, reward, done, _, _ = self.env_step(action)
+
+                irl_reward = self.trainer.get_reward(self.n_states, state_idx)
+                next_state_idx = self.idx_to_state(next_state)
+                self.update_q_table(state_idx, action, irl_reward, next_state_idx)
+
+                learner_feature_expectations += self.trainer.get_feature_matrix()[int(state_idx)]
+
+                score += reward
+                state = next_state
+                if done:
+                    scores.append(score)
+                    episodes.append(episode)
+                    break
+
+            if episode % 100 == 0:
+                score_avg = np.mean(scores)
+                print('{} episode score is {:.2f}'.format(episode, score_avg))
+                plt.plot(episodes, scores, 'b')
+                plt.savefig("./learning_curves/maxent_30000.png")
+                np.save("./results/maxent_30000_table", arr=self.q_table)
+
+    def test(self):
+        episodes, scores = [], []
+
+        for episode in range(10):
+            state = self.env_reset()
+            score = 0
+
+            state = state[0]
+            while True:
+                self.env_render()
+                state_idx = self.idx_to_state(state)
+                action = np.argmax(self.q_table[state_idx])
+                next_state, reward, done, _, _ = self.env_step(action)
+
+                score += reward
+                state = next_state
+
+                if done:
+                    scores.append(score)
+                    episodes.append(episode)
+                    plt.plot(episodes, scores, 'b')
+                    plt.savefig("./learning_curves/maxent_test_30000.png")
+                    break
+
+            if episode % 1 == 0:
+                print('{} episode score is {:.2f}'.format(episode, score))
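The hunks above also drop the duplicate idx_state helper in favour of the existing idx_to_state, which bins the continuous (position, velocity) observation into one of 400 discrete states (20 position buckets × 20 velocity buckets, matching one_feature = 20). Below is a minimal standalone sketch of that binning; the observation bounds are MountainCar-v0's limits hard-coded as an assumption rather than read from the env.

import numpy as np

# Sketch of the discretization performed by idx_to_state (and the removed idx_state).
# ENV_LOW / ENV_HIGH are MountainCar-v0's observation_space limits, assumed here.
ONE_FEATURE = 20
ENV_LOW = np.array([-1.2, -0.07])   # min position, min velocity
ENV_HIGH = np.array([0.6, 0.07])    # max position, max velocity

def to_state_index(state):
    env_distance = (ENV_HIGH - ENV_LOW) / ONE_FEATURE
    position_idx = int((state[0] - ENV_LOW[0]) / env_distance[0])
    velocity_idx = int((state[1] - ENV_LOW[1]) / env_distance[1])
    return position_idx + velocity_idx * ONE_FEATURE  # integer in [0, 400)

print(to_state_index([-0.5, 0.0]))  # prints a bucket index between 0 and 399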
Two binary files changed (1.04 KB, −3.2 KB); previews not shown.

src/irlwpytorch/main.py

Lines changed: 14 additions & 105 deletions
@@ -1,36 +1,12 @@
-"""
-This is a skeleton file that can serve as a starting point for a Python
-console script. To run this script uncomment the following lines in the
-``[options.entry_points]`` section in ``setup.cfg``::
-
-    console_scripts =
-        fibonacci = irlwpytorch.skeleton:run
-
-Then run ``pip install .`` (or ``pip install -e .`` for editable mode)
-which will install the command ``fibonacci`` inside your current environment.
-
-Besides console scripts, the header (i.e. until ``_logger``...) of this file can
-also be used as template for Python modules.
-
-Note:
-    This file can be renamed depending on your needs or safely removed if not needed.
-
-References:
-    - https://setuptools.pypa.io/en/latest/userguide/entry_point.html
-    - https://pip.pypa.io/en/stable/reference/pip_install
-"""
-
 import argparse
-import gym
-import matplotlib.pyplot as plt
 import logging
 import numpy as np
 import sys
 
-from .MountainCar import MountainCar
-from .MaxEntropyIRL import MaxEntropyIRL
+from MountainCar import MountainCar
+from MaxEntropyIRL import MaxEntropyIRL
 
-# from irlwpytorch import __version__
+#from irlwpytorch import __version__
 
 __author__ = "HokageM"
 __copyright__ = "HokageM"
@@ -55,7 +31,7 @@ def parse_args(args):
     parser.add_argument(
         "--version",
         action="version",
-        # version=f"IRLwPytorch {__version__}",
+        # version=f"IRLwPytorch {__version__}",
     )
     parser.add_argument('--training', action='store_true', help="Enables training of model.")
    parser.add_argument('--testing', action='store_true',
@@ -92,103 +68,36 @@ def main(args):
     n_states = 400  # position - 20, velocity - 20
     n_actions = 3
     one_feature = 20  # number of state per one feature
-    feature_matrix = np.eye((n_states))  # (400, 400)
+    feature_matrix = np.eye(n_states)  # (400, 400)
 
     gamma = 0.99
     q_learning_rate = 0.03
     theta_learning_rate = 0.05
 
-    car = None
-    if args.render:
-        car = MountainCar(True, feature_matrix, one_feature, q_learning_rate, gamma)
-    else:
-        car = MountainCar(False, feature_matrix, one_feature, q_learning_rate, gamma)
-
     theta = -(np.random.uniform(size=(n_states,)))
     trainer = MaxEntropyIRL(feature_matrix, theta)
 
+    if args.render:
+        car = MountainCar(True, feature_matrix, one_feature, q_learning_rate, gamma, n_states, trainer)
+    else:
+        car = MountainCar(False, feature_matrix, one_feature, q_learning_rate, gamma, n_states, trainer)
+
     if args.training:
-        q_table = np.zeros((n_states, n_actions))  # (400, 3)
+        q_table = np.zeros((n_states, n_actions))
         car.set_q_table(q_table)
 
-        demonstrations = car.idx_demo(one_feature)
-
-        expert = trainer.expert_feature_expectations(demonstrations)
-        learner_feature_expectations = np.zeros(n_states)
-        episodes, scores = [], []
-
-        for episode in range(30000):
-            state = car.env_reset()
-            score = 0
-
-            if (episode != 0 and episode == 10000) or (episode > 10000 and episode % 5000 == 0):
-                learner = learner_feature_expectations / episode
-                trainer.maxent_irl(expert, learner, theta_learning_rate)
-
-            state = state[0]
-            while True:
-                state_idx = car.idx_state(state)
-                action = np.argmax(q_table[state_idx])
-                next_state, reward, done, _, _ = car.env_step(action)
-
-                irl_reward = trainer.get_reward(n_states, state_idx)
-                next_state_idx = car.idx_state(next_state)
-                car.update_q_table(state_idx, action, irl_reward, next_state_idx)
-
-                learner_feature_expectations += trainer.get_feature_matrix()[int(state_idx)]
-
-                score += reward
-                state = next_state
-                if done:
-                    scores.append(score)
-                    episodes.append(episode)
-                    break
-
-            if episode % 100 == 0:
-                score_avg = np.mean(scores)
-                print('{} episode score is {:.2f}'.format(episode, score_avg))
-                plt.plot(episodes, scores, 'b')
-                plt.savefig("./learning_curves/maxent_30000.png")
-                np.save("./results/maxent_30000_table", arr=q_table)
+        car.train(theta_learning_rate)
 
     if args.testing:
-        q_table = np.load(file="results/maxent_q_table.npy")  # (400, 3)
+        q_table = np.load(file="./results/maxent_q_table.npy")  # (400, 3)
         car.set_q_table(q_table)
 
-        episodes, scores = [], []
-
-        for episode in range(10):
-            state = car.env_reset()
-            score = 0
-
-            state = state[0]
-            while True:
-                car.env_render()
-                state_idx = car.idx_to_state(state)
-                action = np.argmax(q_table[state_idx])
-                next_state, reward, done, _, _ = car.env_step(action)
-
-                score += reward
-                state = next_state
-
-                if done:
-                    scores.append(score)
-                    episodes.append(episode)
-                    plt.plot(episodes, scores, 'b')
-                    plt.savefig("./learning_curves/maxent_test_30000.png")
-                    break
-
-            if episode % 1 == 0:
-                print('{} episode score is {:.2f}'.format(episode, score))
+        car.test()
 
     _logger.info("Script ends here")
 
 
 def run():
-    """Calls :func:`main` passing the CLI arguments extracted from :obj:`sys.argv`
-
-    This function can be used as entry point to create console scripts with setuptools.
-    """
     main(sys.argv[1:])
 
 
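After this refactor, main() only wires up the trainer and the car and delegates to car.train / car.test. Below is a minimal usage sketch of the refactored entry point, invoked programmatically rather than via a console script; the flag names come from parse_args above, while --render is used inside main() but defined outside the hunks shown, so treat it as an assumption.

# Hedged sketch: drive the refactored CLI without a shell,
# roughly equivalent to `python main.py --training` and `python main.py --testing --render`.
from main import main

main(["--training"])             # builds MountainCar(..., n_states, trainer) and calls car.train(theta_learning_rate)
main(["--testing", "--render"])  # loads ./results/maxent_q_table.npy and calls car.test() with rendering enabled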
One binary file changed (0 bytes); content not shown.
