Skip to content

Commit b9e932a

Browse files
William Fedus and psc-g
authored and committed
Explicitly set the terminal_dtype in the replay memories.
PiperOrigin-RevId: 256687171
1 parent 491e786 commit b9e932a

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

dopamine/replay_memory/circular_replay_buffer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def __init__(self,
105105
max_sample_attempts=1000,
106106
extra_storage_types=None,
107107
observation_dtype=np.uint8,
108+
terminal_dtype=np.uint8,
108109
action_shape=(),
109110
action_dtype=np.int32,
110111
reward_shape=(),
@@ -124,6 +125,8 @@ def __init__(self,
124125
contents that will be stored and returned by sample_transition_batch.
125126
observation_dtype: np.dtype, type of the observations. Defaults to
126127
np.uint8 for Atari 2600.
128+
terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for
129+
Atari 2600.
127130
action_shape: tuple of ints, the shape for the action vector. Empty tuple
128131
means the action is a scalar.
129132
action_dtype: np.dtype, type of elements in the action.
@@ -145,6 +148,7 @@ def __init__(self,
145148
self.__class__.__name__)
146149
tf.logging.info('\t observation_shape: %s', str(observation_shape))
147150
tf.logging.info('\t observation_dtype: %s', str(observation_dtype))
151+
tf.logging.info('\t terminal_dtype: %s', str(terminal_dtype))
148152
tf.logging.info('\t stack_size: %d', stack_size)
149153
tf.logging.info('\t replay_capacity: %d', replay_capacity)
150154
tf.logging.info('\t batch_size: %d', batch_size)
@@ -163,6 +167,7 @@ def __init__(self,
163167
self._update_horizon = update_horizon
164168
self._gamma = gamma
165169
self._observation_dtype = observation_dtype
170+
self._terminal_dtype = terminal_dtype
166171
self._max_sample_attempts = max_sample_attempts
167172
if extra_storage_types:
168173
self._extra_storage_types = extra_storage_types
@@ -210,7 +215,7 @@ def get_storage_signature(self):
210215
self._observation_dtype),
211216
ReplayElement('action', self._action_shape, self._action_dtype),
212217
ReplayElement('reward', self._reward_shape, self._reward_dtype),
213-
ReplayElement('terminal', (), np.uint8)
218+
ReplayElement('terminal', (), self._terminal_dtype)
214219
]
215220

216221
for extra_replay_element in self._extra_storage_types:
@@ -241,7 +246,7 @@ def add(self, observation, action, reward, terminal, *args):
241246
observation: np.array with shape observation_shape.
242247
action: int, the action in the transition.
243248
reward: float, the reward received in the transition.
244-
terminal: A uint8 acting as a boolean indicating whether the transition
249+
terminal: np.dtype, acts as a boolean indicating whether the transition
245250
was terminal (1) or not (0).
246251
*args: extra contents with shapes and dtypes according to
247252
extra_storage_types.
@@ -555,7 +560,7 @@ def get_transition_elements(self, batch_size=None):
555560
self._action_dtype),
556561
ReplayElement('next_reward', (batch_size,) + self._reward_shape,
557562
self._reward_dtype),
558-
ReplayElement('terminal', (batch_size,), np.uint8),
563+
ReplayElement('terminal', (batch_size,), self._terminal_dtype),
559564
ReplayElement('indices', (batch_size,), np.int32)
560565
]
561566
for element in self._extra_storage_types:
@@ -687,6 +692,7 @@ def __init__(self,
687692
max_sample_attempts=1000,
688693
extra_storage_types=None,
689694
observation_dtype=np.uint8,
695+
terminal_dtype=np.uint8,
690696
action_shape=(),
691697
action_dtype=np.int32,
692698
reward_shape=(),
@@ -710,6 +716,8 @@ def __init__(self,
710716
contents that will be stored and returned by sample_transition_batch.
711717
observation_dtype: np.dtype, type of the observations. Defaults to
712718
np.uint8 for Atari 2600.
719+
terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for
720+
Atari 2600.
713721
action_shape: tuple of ints, the shape for the action vector. Empty tuple
714722
means the action is a scalar.
715723
action_dtype: np.dtype, type of elements in the action.
@@ -745,6 +753,7 @@ def __init__(self,
745753
gamma,
746754
max_sample_attempts,
747755
observation_dtype=observation_dtype,
756+
terminal_dtype=terminal_dtype,
748757
extra_storage_types=extra_storage_types,
749758
action_shape=action_shape,
750759
action_dtype=action_dtype,
@@ -765,7 +774,7 @@ def add(self, observation, action, reward, terminal, *args):
765774
observation: np.array with shape observation_shape.
766775
action: int, the action in the transition.
767776
reward: float, the reward received in the transition.
768-
terminal: A uint8 acting as a boolean indicating whether the transition
777+
terminal: np.dtype, acts as a boolean indicating whether the transition
769778
was terminal (1) or not (0).
770779
*args: extra contents with shapes and dtypes according to
771780
extra_storage_types.

dopamine/replay_memory/prioritized_replay_buffer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(self,
5050
max_sample_attempts=1000,
5151
extra_storage_types=None,
5252
observation_dtype=np.uint8,
53+
terminal_dtype=np.uint8,
5354
action_shape=(),
5455
action_dtype=np.int32,
5556
reward_shape=(),
@@ -69,6 +70,8 @@ def __init__(self,
6970
contents that will be stored and returned by sample_transition_batch.
7071
observation_dtype: np.dtype, type of the observations. Defaults to
7172
np.uint8 for Atari 2600.
73+
terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for
74+
Atari 2600.
7275
action_shape: tuple of ints, the shape for the action vector. Empty tuple
7376
means the action is a scalar.
7477
action_dtype: np.dtype, type of elements in the action.
@@ -86,6 +89,7 @@ def __init__(self,
8689
max_sample_attempts=max_sample_attempts,
8790
extra_storage_types=extra_storage_types,
8891
observation_dtype=observation_dtype,
92+
terminal_dtype=terminal_dtype,
8993
action_shape=action_shape,
9094
action_dtype=action_dtype,
9195
reward_shape=reward_shape,
@@ -274,6 +278,7 @@ def __init__(self,
274278
max_sample_attempts=1000,
275279
extra_storage_types=None,
276280
observation_dtype=np.uint8,
281+
terminal_dtype=np.uint8,
277282
action_shape=(),
278283
action_dtype=np.int32,
279284
reward_shape=(),
@@ -295,6 +300,8 @@ def __init__(self,
295300
contents that will be stored and returned by sample_transition_batch.
296301
observation_dtype: np.dtype, type of the observations. Defaults to
297302
np.uint8 for Atari 2600.
303+
terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for
304+
Atari 2600.
298305
action_shape: tuple of ints, the shape for the action vector. Empty tuple
299306
means the action is a scalar.
300307
action_dtype: np.dtype, type of elements in the action.
@@ -322,6 +329,7 @@ def __init__(self,
322329
wrapped_memory=memory,
323330
extra_storage_types=extra_storage_types,
324331
observation_dtype=observation_dtype,
332+
terminal_dtype=terminal_dtype,
325333
action_shape=action_shape,
326334
action_dtype=action_dtype,
327335
reward_shape=reward_shape,

tests/dopamine/replay_memory/circular_replay_buffer_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def testConstructor(self):
7979
batch_size=BATCH_SIZE)
8080
self.assertEqual(memory._observation_shape, (4, 20))
8181
self.assertEqual(memory.add_count, 0)
82+
# Test with terminal datatype of np.int32
83+
memory = circular_replay_buffer.OutOfGraphReplayBuffer(
84+
observation_shape=OBSERVATION_SHAPE,
85+
stack_size=STACK_SIZE,
86+
terminal_dtype=np.int32,
87+
replay_capacity=5,
88+
batch_size=BATCH_SIZE)
89+
self.assertEqual(memory._terminal_dtype, np.int32)
8290

8391
def testAdd(self):
8492
memory = circular_replay_buffer.OutOfGraphReplayBuffer(

0 commit comments

Comments (0)