
Commit 43b964e

update learning curves and README (#8)
1 parent 67539d1 commit 43b964e

File tree

9 files changed: +15 -139 lines changed


README.md

Lines changed: 8 additions & 3 deletions
@@ -5,8 +5,9 @@
 Inverse Reinforcement Learning Algorithm implementation with python.
 
 Implemented Algorithms:
-- Maximum Entropy IRL
-- Maximum Entropy Deep IRL
+- Maximum Entropy IRL: [1]
+- Discrete Maximum Entropy Deep IRL: [2, 3]
+- IQ-Learn
 
 Experiment:
 - Mountaincar: [gym](https://www.gymlibrary.dev/environments/classic_control/mountain_car/)
@@ -16,7 +17,11 @@ The implementation of MaxEntropyIRL and MountainCar is based on the implementati
 
 # References
 
-...
+[1] [B. D. Ziebart, et al., "Maximum Entropy Inverse Reinforcement Learning", AAAI 2008](https://cdn.aaai.org/AAAI/2008/AAAI08-227.pdf)
+
+[2] [M. Wulfmeier, et al., "Maximum Entropy Deep Inverse Reinforcement Learning", arXiv preprint arXiv:1507.04888, 2015](https://arxiv.org/abs/1507.04888)
+
+[3] [Xi-liang Chen, et al., "A Study of Continuous Maximum Entropy Deep Inverse Reinforcement Learning", Mathematical Problems in Engineering, vol. 2019, Article ID 4834516, 2019](https://www.hindawi.com/journals/mpe/2019/4834516/)
 
 # Installation
 
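For orientation, below is a minimal sketch of the Maximum Entropy IRL weight update described in reference [1] above; the function and variable names (maxent_irl_step, feat_matrix, expert_svf, theta) are illustrative and are not taken from this repository's code.

import numpy as np

def maxent_irl_step(feat_matrix, expert_svf, learner_svf, theta, learning_rate=0.05):
    # feat_matrix: (n_states, n_features) state feature matrix
    # expert_svf:  (n_states,) empirical state-visitation frequencies of the expert demos
    # learner_svf: (n_states,) expected visitation frequencies under the current reward
    # theta:       (n_features,) reward weights, so reward(s) = feat_matrix[s] @ theta
    # Gradient of the MaxEnt log-likelihood: expert feature expectations minus
    # the feature expectations induced by the current reward estimate.
    grad = feat_matrix.T @ (expert_svf - learner_svf)
    return theta + learning_rate * grad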

src/irlwpython/DiscreteMaxEntropyDeepIRL.py

Lines changed: 7 additions & 4 deletions
@@ -136,11 +136,14 @@ def train(self):
             score_avg = np.mean(scores)
             print('{} episode score is {:.2f}'.format(episode, score_avg))
             plt.plot(episodes, scores, 'b')
-            plt.savefig("./learning_curves/maxent_30000_network.png")
+            plt.savefig("./learning_curves/discretemaxentdeep_30000.png")
 
-        torch.save(self.q_network.state_dict(), "./results/maxent_30000_q_network.pth")
+        torch.save(self.actor_network.state_dict(), "./results/discretemaxentdeep_30000_actor.pth")
+        torch.save(self.critic_network.state_dict(), "./results/discretemaxentdeep_30000_critic.pth")
 
     def test(self):
+        assert 1 == 0  # TODO: not implemented yet
+
         episodes, scores = [], []
 
         for episode in range(10):
@@ -151,7 +154,7 @@ def test(self):
                 self.target.env_render()
                 state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
 
-                action = torch.argmax(self.q_network(state_tensor)).item()
+                action = torch.argmax(self.actor_network(state_tensor)).item()
                 next_state, reward, done, _, _ = self.target.env_step(action)
 
                 score += reward
@@ -161,7 +164,7 @@ def test(self):
                     scores.append(score)
                     episodes.append(episode)
                     plt.plot(episodes, scores, 'b')
-                    plt.savefig("./learning_curves/maxent_test_30000_network.png")
+                    plt.savefig("./learning_curves/discretemaxentdeep_test_30000.png")
                     break
 
             if episode % 1 == 0:
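As a usage note, here is a hypothetical evaluation snippet that reloads the actor checkpoint saved above and queries it greedily, mirroring the torch.argmax call in test(); ActorNetwork and its constructor arguments are placeholders, not the repository's actual class signature.

import torch

# Rebuild the actor with the architecture used in training (placeholder signature),
# then restore the weights written by train() in this commit.
actor = ActorNetwork(state_dim=2, n_actions=3)
actor.load_state_dict(torch.load("./results/discretemaxentdeep_30000_actor.pth"))
actor.eval()

# Greedy action for a single MountainCar state (position, velocity).
state_tensor = torch.tensor([-0.5, 0.0], dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
    action = torch.argmax(actor(state_tensor)).item()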
5 binary/image files changed (-5.34 MB, 2.97 KB, 9.02 KB, 357 Bytes, 0 Bytes); file names and contents not shown.

src/irlwpython/utils/utils.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

src/irlwpython/utils/zfilter.py

Lines changed: 0 additions & 86 deletions
This file was deleted.

0 commit comments

Comments
 (0)