Skip to content

Commit 309f4c4

Browse files
authored
cleaning up
1 parent 133f423 commit 309f4c4

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

mitdeeplearning/lab3.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,9 @@ def play_video(filename):
1414

1515
return embedded
1616

17-
def preprocess_pong(image):
18-
I = image[35:195] # Crop
19-
I = I[::2, ::2, 0] # Downsample width and height by a factor of 2
20-
I[I == 144] = 0 # Remove background type 1
21-
I[I == 109] = 0 # Remove background type 2
22-
I[I != 0] = 1 # Set remaining elements (paddles, ball, etc.) to 1
23-
return I.astype(np.float).reshape(80, 80, 1)
2417

25-
def new_preprocess_pong(image):
18+
def preprocess_pong(image):
2619
I = image[35:195] # Crop
27-
# I = np.mean(I, axis=-1, keepdim=True)
2820
I = I[::2, ::2, 0] # Downsample width and height by a factor of 2
2921
I[I == 144] = 0 # Remove background type 1
3022
I[I == 109] = 0 # Remove background type 2
@@ -41,6 +33,7 @@ def pong_change(prev, curr):
4133
I = (I - I.min()) / (I.max() - I.min() + 1e-10)
4234
return I
4335

36+
4437
class Memory:
4538
def __init__(self):
4639
self.clear()
@@ -54,15 +47,10 @@ def clear(self):
5447
# Add observations, actions, rewards to memory
5548
def add_to_memory(self, new_observation, new_action, new_reward):
5649
self.observations.append(new_observation)
57-
'''TODO: update the list of actions with new action'''
58-
self.actions.append(new_action) # TODO
59-
# ['''TODO''']
60-
'''TODO: update the list of rewards with new reward'''
61-
self.rewards.append(new_reward) # TODO
62-
# ['''TODO''']
63-
64-
# Helper function to combine a list of Memory objects into a single Memory.
65-
# This will be very useful for batching.
50+
self.actions.append(new_action)
51+
self.rewards.append(new_reward)
52+
53+
6654
def aggregate_memories(memories):
6755
batch_memory = Memory()
6856

@@ -72,6 +60,7 @@ def aggregate_memories(memories):
7260

7361
return batch_memory
7462

63+
7564
def parallelized_collect_rollout(batch_size, envs, model, choose_action):
7665

7766
assert len(envs) == batch_size, "Number of parallel environments must be equal to the batch size."

0 commit comments

Comments (0)