File tree Expand file tree Collapse file tree 1 file changed +31
-0
lines changed Expand file tree Collapse file tree 1 file changed +31
-0
lines changed Original file line number Diff line number Diff line change @@ -41,6 +41,37 @@ def pong_change(prev, curr):
41
41
I = (I - I .min ()) / (I .max () - I .min () + 1e-10 )
42
42
return I
43
43
44
class Memory:
    """Buffer holding parallel lists of observations, actions and rewards.

    Entries are only ever added through ``add_to_memory``, which appends one
    item to each list, so index ``i`` of every list refers to the same step.
    """

    def __init__(self):
        # Start empty; clear() is the single place the three lists are created.
        self.clear()

    def clear(self):
        """Reset/restart the buffer by rebinding all three lists to new empty lists."""
        self.observations = []
        self.actions = []
        self.rewards = []

    def add_to_memory(self, new_observation, new_action, new_reward):
        """Append one (observation, action, reward) step to the buffer.

        Args:
            new_observation: value appended to ``self.observations``.
            new_action: value appended to ``self.actions``.
            new_reward: value appended to ``self.rewards``.
        """
        self.observations.append(new_observation)
        self.actions.append(new_action)
        self.rewards.append(new_reward)
63
+
64
+ # Helper function to combine a list of Memory objects into a single Memory.
65
+ # This will be very useful for batching.
66
def aggregate_memories(memories):
    """Combine an iterable of Memory objects into one new Memory.

    Steps are copied in order: all steps of the first memory, then all steps
    of the second, and so on. The input memories are only read, not modified.

    Args:
        memories: iterable of objects exposing ``observations``, ``actions``
            and ``rewards`` sequences.

    Returns:
        A new Memory containing every copied step.
    """
    combined = Memory()

    for source in memories:
        # zip stops at the shortest of the three sequences, so only complete
        # (observation, action, reward) triples are carried over.
        transitions = zip(source.observations, source.actions, source.rewards)
        for observation, action, reward in transitions:
            combined.add_to_memory(observation, action, reward)

    return combined
74
+
44
75
def parallelized_collect_rollout (batch_size , envs , model , choose_action ):
45
76
46
77
assert len (envs ) == batch_size , "Number of parallel environments must be equal to the batch size."
You can’t perform that action at this time.
0 commit comments