Skip to content

Commit cb87105

Browse files
committed
Work towards simplifying ReplayMemory
1 parent c6172f3 commit cb87105

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed

rltorch/memory/SimplifiedMemory.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from random import sample
2+
from collections import deque
3+
4+
class ReplayMemory:
    """
    Fixed-capacity FIFO buffer of transitions.

    Transitions are kept in insertion order. Once ``capacity`` transitions
    are stored, appending another one evicts the oldest.

    Parameters
    ----------
    capacity : int
        The maximum size of the buffer
    """
    def __init__(self, capacity):
        self.capacity = capacity
        # deque(maxlen=...) discards the oldest entry automatically when full.
        self.memory = deque(maxlen=capacity)

    def append(self, **kwargs):
        """
        Adds a transition to the buffer.

        Parameters
        ----------
        **kwargs
            The fields of the transition (e.g. state, action, reward,
            next_state, done). They are stored together as a single dict.
        """
        self.memory.append(kwargs)

    def clear(self):
        """
        Clears the buffer.
        """
        self.memory.clear()

    def _encode_sample(self, indices):
        """
        Returns the stored transitions at ``indices``, in the given order.

        Parameters
        ----------
        indices : iterable of int
            Positions in the buffer (0 is the oldest transition).
        """
        return [self.memory[i] for i in indices]

    def sample(self, batch_size):
        """
        Returns a random sample from the buffer.

        Transitions are drawn without replacement.

        Parameters
        ----------
        batch_size : int
            The number of observations to sample.

        Returns
        -------
        list of dict
            The sampled transitions.

        Raises
        ------
        ValueError
            If ``batch_size`` is larger than the number of stored transitions.
        """
        return sample(self.memory, batch_size)

    def sample_n_steps(self, batch_size, steps):
        r"""
        Returns a random sample of sequential batches of size steps.

        Notes
        -----
        The number of batches sampled is :math:`\lfloor\frac{batch\_size}{steps}\rfloor`.

        Each batch is a run of ``steps`` transitions that are adjacent in the
        buffer; the runs are concatenated into one flat list. Start positions
        are drawn without replacement, but runs may still overlap.

        Parameters
        ----------
        batch_size : int
            The total number of observations to sample.
        steps : int
            The number of consecutive observations in each sequence,
            counting the randomly selected starting observation.

        Returns
        -------
        list of dict
            The sampled transitions, grouped into consecutive runs of
            length ``steps``.

        Raises
        ------
        ValueError
            If there are fewer valid start positions than sequences requested.
        ZeroDivisionError
            If ``steps`` is zero.
        """
        # A run starting at i covers i .. i + steps - 1, so the last valid
        # start is len - steps. The "+ 1" keeps that start in the population;
        # without it the newest transition could never be sampled and a
        # buffer holding exactly `steps` transitions could not be sampled.
        start_idxes = sample(
            range(len(self.memory) - steps + 1),
            batch_size // steps
        )
        step_idxes = []
        for start in start_idxes:
            step_idxes.extend(range(start, start + steps))
        return self._encode_sample(step_idxes)

    def __len__(self):
        """
        Returns the number of transitions currently stored.
        """
        return len(self.memory)

0 commit comments

Comments
 (0)