Skip to content

Commit ac834ef

Browse files
authored
adding Memory class
1 parent 758a3eb commit ac834ef

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

mitdeeplearning/lab3.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,37 @@ def pong_change(prev, curr):
4141
I = (I - I.min()) / (I.max() - I.min() + 1e-10)
4242
return I
4343

44+
class Memory:
    """Step-by-step buffer of observations, actions and rewards.

    The three lists are filled in lockstep by :meth:`add_to_memory`, so
    entry ``i`` of ``observations``, ``actions`` and ``rewards`` all come
    from the same call.
    """

    def __init__(self):
        # clear() is what creates the three lists, so a new buffer starts empty.
        self.clear()

    def clear(self):
        """Reset the buffer by rebinding all three attributes to new empty lists."""
        self.observations = []
        self.actions = []
        self.rewards = []

    def add_to_memory(self, new_observation, new_action, new_reward):
        """Append a single step to the buffer.

        Args:
            new_observation: Value appended to ``self.observations``.
            new_action: Value appended to ``self.actions``.
            new_reward: Value appended to ``self.rewards``.
        """
        self.observations.append(new_observation)
        self.actions.append(new_action)
        self.rewards.append(new_reward)
63+
64+
# Helper function to combine a list of Memory objects into a single Memory.
65+
# This will be very useful for batching.
66+
def aggregate_memories(memories):
    """Merge several Memory objects into one new Memory, preserving order.

    Args:
        memories: Iterable of objects exposing ``observations``, ``actions``
            and ``rewards`` sequences.

    Returns:
        A new ``Memory`` holding every step of every input, in input order.
        Steps are paired with ``zip``, so an input whose three sequences
        differ in length contributes only as many steps as its shortest one.
    """
    combined = Memory()

    for episode in memories:
        steps = zip(episode.observations, episode.actions, episode.rewards)
        for observation, action, reward in steps:
            combined.add_to_memory(observation, action, reward)

    return combined
74+
4475
def parallelized_collect_rollout(batch_size, envs, model, choose_action):
4576

4677
assert len(envs) == batch_size, "Number of parallel environments must be equal to the batch size."

0 commit comments

Comments
 (0)