Skip to content

Commit 3134f5c

Browse files
authored
fix(OutputHandler): check if directory exists (#16)
* Check if directory exists * flake8
1 parent 8cd2b79 commit 3134f5c

File tree

4 files changed

+52
-28
lines changed

4 files changed

+52
-28
lines changed

src/irlwpython/MaxEntropyDeepIRL.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.optim as optim
66
import torch.nn as nn
77

8-
from irlwpython.FigurePrinter import FigurePrinter
8+
from irlwpython.OutputHandler import OutputHandler
99

1010

1111
class QNetwork(nn.Module):
@@ -17,7 +17,7 @@ def __init__(self, input_size, output_size):
1717
self.relu2 = nn.ReLU()
1818
self.output_layer = nn.Linear(32, output_size)
1919

20-
self.printer = FigurePrinter()
20+
self.output_hand = OutputHandler()
2121

2222
def forward(self, state):
2323
x = self.fc1(state)
@@ -46,7 +46,7 @@ def __init__(self, target, state_dim, action_size, feature_matrix, one_feature,
4646
self.theta_learning_rate = theta_learning_rate
4747
self.theta = theta
4848

49-
self.printer = FigurePrinter()
49+
self.printer = OutputHandler()
5050

5151
def select_action(self, state, epsilon):
5252
"""

src/irlwpython/MaxEntropyDeepRL.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.optim as optim
66
import torch.nn as nn
77

8-
from irlwpython.FigurePrinter import FigurePrinter
8+
from irlwpython.OutputHandler import OutputHandler
99

1010

1111
class QNetwork(nn.Module):
@@ -17,7 +17,7 @@ def __init__(self, input_size, output_size):
1717
self.relu2 = nn.ReLU()
1818
self.output_layer = nn.Linear(32, output_size)
1919

20-
self.printer = FigurePrinter()
20+
self.output_hand = OutputHandler()
2121

2222
def forward(self, state):
2323
x = self.fc1(state)
@@ -42,7 +42,7 @@ def __init__(self, target, state_dim, action_size, feature_matrix, one_feature,
4242

4343
self.gamma = gamma
4444

45-
self.printer = FigurePrinter()
45+
self.output_hand = OutputHandler()
4646

4747
def select_action(self, state, epsilon):
4848
"""
@@ -150,17 +150,18 @@ def train(self, n_states, episodes=30000, max_steps=200,
150150
if (episode + 1) % 1000 == 0:
151151
score_avg = np.mean(scores)
152152
print('{} episode average score is {:.2f}'.format(episode, score_avg))
153-
self.printer.save_plot_as_png(episode_arr, scores,
154-
f"../learning_curves/maxent_{episodes}_{episode}_qnetwork_RL.png")
155-
self.printer.save_heatmap_as_png(learner.reshape((20, 20)), f"../heatmap/learner_{episode}_deep_RL.png")
156-
self.printer.save_heatmap_as_png(self.theta.reshape((20, 20)),
157-
f"../heatmap/theta_{episode}_deep_RL.png")
153+
self.output_hand.save_plot_as_png(episode_arr, scores,
154+
f"../learning_curves/maxent_{episodes}_{episode}_qnetwork_RL.png")
155+
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)),
156+
f"../heatmap/learner_{episode}_deep_RL.png")
157+
self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)),
158+
f"../heatmap/theta_{episode}_deep_RL.png")
158159

159160
torch.save(self.q_network.state_dict(), f"../results/maxent_{episodes}_{episode}_network_main.pth")
160161

161162
if episode == episodes - 1:
162-
self.printer.save_plot_as_png(episode_arr, scores,
163-
f"../learning_curves/maxentdeep_{episodes}_qdeep_RL.png")
163+
self.output_hand.save_plot_as_png(episode_arr, scores,
164+
f"../learning_curves/maxentdeep_{episodes}_qdeep_RL.png")
164165

165166
torch.save(self.q_network.state_dict(), f"src/irlwpython/results/maxentdeep_{episodes}_q_network_RL.pth")
166167

@@ -192,6 +193,6 @@ def test(self, model_path, epsilon=0.01, repeats=100):
192193
if episode % 1 == 0:
193194
print('{} episode score is {:.2f}'.format(episode, score))
194195

195-
self.printer.save_plot_as_png(episodes, scores,
196-
"src/irlwpython/learning_curves"
197-
"/test_maxentropydeep_best_model_RL_results.png")
196+
self.output_hand.save_plot_as_png(episodes, scores,
197+
"src/irlwpython/learning_curves"
198+
"/test_maxentropydeep_best_model_RL_results.png")

src/irlwpython/MaxEntropyIRL.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88

9-
from irlwpython.FigurePrinter import FigurePrinter
9+
from irlwpython.OutputHandler import OutputHandler
1010

1111

1212
class MaxEntropyIRL:
@@ -20,7 +20,7 @@ def __init__(self, target, feature_matrix, one_feature, q_table, q_learning_rate
2020
self.gamma = gamma
2121
self.n_states = n_states
2222

23-
self.printer = FigurePrinter()
23+
self.output_hand = OutputHandler()
2424

2525
def get_feature_matrix(self):
2626
"""
@@ -133,13 +133,13 @@ def train(self, theta_learning_rate, episode_count=30000):
133133
if (episode + 1) % 1000 == 0:
134134
score_avg = np.mean(scores)
135135
print('{} episode score is {:.2f}'.format(episode, score_avg))
136-
self.printer.save_plot_as_png(episodes, scores,
137-
f"src/irlwpython/learning_curves/"
138-
f"maxent_{episode_count}_{episode}_qtable.png")
139-
self.printer.save_heatmap_as_png(learner.reshape((20, 20)),
140-
f"src/irlwpython/heatmap/learner_{episode}_flat.png")
141-
self.printer.save_heatmap_as_png(self.theta.reshape((20, 20)),
142-
f"src/irlwpython/heatmap/theta_{episode}_flat.png")
136+
self.output_hand.save_plot_as_png(episodes, scores,
137+
f"src/irlwpython/learning_curves/"
138+
f"maxent_{episode_count}_{episode}_qtable.png")
139+
self.output_hand.save_heatmap_as_png(learner.reshape((20, 20)),
140+
f"src/irlwpython/heatmap/learner_{episode}_flat.png")
141+
self.output_hand.save_heatmap_as_png(self.theta.reshape((20, 20)),
142+
f"src/irlwpython/heatmap/theta_{episode}_flat.png")
143143

144144
np.save(f"src/irlwpython/results/maxent_{episode}_qtable", arr=self.q_table)
145145

@@ -172,5 +172,5 @@ def test(self, repeats=100):
172172
if episode % 1 == 0:
173173
print('{} episode score is {:.2f}'.format(episode, score))
174174

175-
self.printer.save_plot_as_png(episodes, scores,
176-
"src/irlwpython/learning_curves/test_maxentropy_flat.png")
175+
self.output_hand.save_plot_as_png(episodes, scores,
176+
"src/irlwpython/learning_curves/test_maxentropy_flat.png")

src/irlwpython/FigurePrinter.py renamed to src/irlwpython/OutputHandler.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import matplotlib.pyplot as plt
2+
import os
23

34

4-
class FigurePrinter:
5+
class OutputHandler:
56
def __int__(self):
67
pass
78

@@ -25,6 +26,11 @@ def save_heatmap_as_png(self, data, output_path, title=None, xlabel="Position",
2526
if title:
2627
plt.title(title)
2728

29+
target_dir = os.path.basename(output_path)
30+
if not os.path.isdir(target_dir):
31+
print(f"Creating directory {target_dir}")
32+
os.mkdir(target_dir)
33+
2834
plt.savefig(output_path, format='png')
2935
plt.close(fig)
3036

@@ -48,5 +54,22 @@ def save_plot_as_png(self, x, y, output_path, title=None, xlabel="Episodes", yla
4854
if title:
4955
plt.title(title)
5056

57+
target_dir = os.path.basename(output_path)
58+
if not os.path.isdir(target_dir):
59+
print(f"Creating directory {target_dir}")
60+
os.mkdir(target_dir)
61+
5162
plt.savefig(output_path, format='png')
5263
plt.close(fig)
64+
65+
def save_network(self, network, output_path):
66+
target_dir = os.path.basename(output_path)
67+
if not os.path.isdir(target_dir):
68+
print(f"Creating directory {target_dir}")
69+
os.mkdir(target_dir)
70+
71+
def save_qtable(self, qtable, output_path):
72+
target_dir = os.path.basename(output_path)
73+
if not os.path.isdir(target_dir):
74+
print(f"Creating directory {target_dir}")
75+
os.mkdir(target_dir)

0 commit comments

Comments
 (0)