Commit 02a68d7 — "Support state with history length = 1"
Parent: d410dcb

6 files changed: +29 −9 lines

fruit/buffers/replay.py (4 additions, 1 deletion)

```diff
@@ -53,7 +53,10 @@ def append(self, state, action, reward, next_state, terminal):
         else:
             self.start_index = (self.start_index + 1) % self.max_size

-        self.states[insert_index] = state[-1]
+        if self.state_history > 1:
+            self.states[insert_index] = state[-1]
+        else:
+            self.states[insert_index] = state

     def get_state(self, index):
         if self.current_size < self.max_size:
```

fruit/buffers/tree.py (4 additions, 1 deletion)

```diff
@@ -173,7 +173,10 @@ def append(self, state=None, action=0, reward=0, next_state=None, terminal=False
             self.start_index = (self.start_index + 1) % self.max_size
             self.__modify(self.num_of_levels-1, insert_index, pre_p**self.alpha, priority**self.alpha)

-        self.states[insert_index] = state[-1]
+        if self.state_history > 1:
+            self.states[insert_index] = state[-1]
+        else:
+            self.states[insert_index] = state

     def __update(self, new_level, old_index, new_value):
         new_index = int(old_index/2)
```

fruit/learners/dqn.py (2 additions, 1 deletion)

```diff
@@ -19,7 +19,8 @@ def __init__(self, agent, name, environment, network, global_dict, report_freque
         global experience_replay
         with global_dict[AgentMonitor.Q_LOCK]:
             if experience_replay is None:
-                experience_replay = SyncExperienceReplay(experience_replay_size)
+                experience_replay = SyncExperienceReplay(experience_replay_size,
+                                                         state_history=network.network_config.get_history_length())
         self.replay = experience_replay
         self.batch_size = batch_size
         self.warmup_steps = warmup_steps
```

fruit/monitor/monitor.py (5 additions, 0 deletions)

```diff
@@ -208,5 +208,10 @@ def run_epochs(self, learners):
         self.shared_dict[AgentMonitor.Q_FINISH] = True
         for t in threads:
             t.join()
+
+        current_epoch = self.shared_dict[AgentMonitor.Q_GLOBAL_STEPS] / self.epoch_steps
+        et = time.time()
+        self.__print_log(et - st, current_epoch)
+
         print('All threads stopped')
         return self.shared_dict[AgentMonitor.Q_REWARD_LIST]
```

fruit/samples/multi_objectives_test.py (12 additions, 5 deletions)

```diff
@@ -40,12 +40,19 @@ def train_multi_objective_agent_mountain_car():
     agent.train()


-def train_multi_objective_dqn_agent(is_linear=True, extended_config=True):
-    # Create a Deep Sea Treasure game
-    game = DeepSeaTreasure(graphical_state=True, width=5, seed=100, render=False, max_treasure=100, speed=1000)
+def train_multi_objective_dqn_agent(is_linear=False, extended_config=False):
+    if extended_config:
+        # Create a Deep Sea Treasure game
+        game = DeepSeaTreasure(graphical_state=True, width=5, seed=100, render=False, max_treasure=100, speed=1000)
+
+        # Put game into fruit wrapper
+        environment = FruitEnvironment(game, max_episode_steps=60, state_processor=AtariProcessor())
+    else:
+        # Create a Deep Sea Treasure game
+        game = DeepSeaTreasure(graphical_state=False, width=5, seed=100, render=False, max_treasure=100, speed=1000)

-    # Put game into fruit wrapper
-    environment = FruitEnvironment(game, max_episode_steps=60, state_processor=AtariProcessor())
+        # Put game into fruit wrapper
+        environment = FruitEnvironment(game, max_episode_steps=60)

     # Get treasures
     treasures = game.get_treasure()
```

requirements.txt (2 additions, 1 deletion)

```diff
@@ -5,4 +5,5 @@ Pillow==4.3.0
 psutil==5.6.2
 statsmodels==0.10.1
 tensorflow-gpu==1.12.0
-matplotlib==3.1.1
+matplotlib==3.1.1
+pygame==1.9.6
```

0 commit comments.