@@ -105,6 +105,7 @@ def __init__(self,
105105 max_sample_attempts = 1000 ,
106106 extra_storage_types = None ,
107107 observation_dtype = np .uint8 ,
108+ terminal_dtype = np .uint8 ,
108109 action_shape = (),
109110 action_dtype = np .int32 ,
110111 reward_shape = (),
@@ -124,6 +125,8 @@ def __init__(self,
124125 contents that will be stored and returned by sample_transition_batch.
125126 observation_dtype: np.dtype, type of the observations. Defaults to
126127 np.uint8 for Atari 2600.
128+ terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for
129+ Atari 2600.
127130 action_shape: tuple of ints, the shape for the action vector. Empty tuple
128131 means the action is a scalar.
129132 action_dtype: np.dtype, type of elements in the action.
@@ -145,6 +148,7 @@ def __init__(self,
145148 self .__class__ .__name__ )
146149 tf .logging .info ('\t observation_shape: %s' , str (observation_shape ))
147150 tf .logging .info ('\t observation_dtype: %s' , str (observation_dtype ))
151+ tf .logging .info ('\t terminal_dtype: %s' , str (terminal_dtype ))
148152 tf .logging .info ('\t stack_size: %d' , stack_size )
149153 tf .logging .info ('\t replay_capacity: %d' , replay_capacity )
150154 tf .logging .info ('\t batch_size: %d' , batch_size )
@@ -163,6 +167,7 @@ def __init__(self,
163167 self ._update_horizon = update_horizon
164168 self ._gamma = gamma
165169 self ._observation_dtype = observation_dtype
170+ self ._terminal_dtype = terminal_dtype
166171 self ._max_sample_attempts = max_sample_attempts
167172 if extra_storage_types :
168173 self ._extra_storage_types = extra_storage_types
@@ -210,7 +215,7 @@ def get_storage_signature(self):
210215 self ._observation_dtype ),
211216 ReplayElement ('action' , self ._action_shape , self ._action_dtype ),
212217 ReplayElement ('reward' , self ._reward_shape , self ._reward_dtype ),
213- ReplayElement ('terminal' , (), np . uint8 )
218+ ReplayElement ('terminal' , (), self . _terminal_dtype )
214219 ]
215220
216221 for extra_replay_element in self ._extra_storage_types :
@@ -241,7 +246,7 @@ def add(self, observation, action, reward, terminal, *args):
241246 observation: np.array with shape observation_shape.
242247 action: int, the action in the transition.
243248 reward: float, the reward received in the transition.
244- terminal: A uint8 acting as a boolean indicating whether the transition
249+ terminal: A value of terminal_dtype acting as a boolean indicating whether the transition
245250 was terminal (1) or not (0).
246251 *args: extra contents with shapes and dtypes according to
247252 extra_storage_types.
@@ -555,7 +560,7 @@ def get_transition_elements(self, batch_size=None):
555560 self ._action_dtype ),
556561 ReplayElement ('next_reward' , (batch_size ,) + self ._reward_shape ,
557562 self ._reward_dtype ),
558- ReplayElement ('terminal' , (batch_size ,), np . uint8 ),
563+ ReplayElement ('terminal' , (batch_size ,), self . _terminal_dtype ),
559564 ReplayElement ('indices' , (batch_size ,), np .int32 )
560565 ]
561566 for element in self ._extra_storage_types :
@@ -687,6 +692,7 @@ def __init__(self,
687692 max_sample_attempts = 1000 ,
688693 extra_storage_types = None ,
689694 observation_dtype = np .uint8 ,
695+ terminal_dtype = np .uint8 ,
690696 action_shape = (),
691697 action_dtype = np .int32 ,
692698 reward_shape = (),
@@ -710,6 +716,8 @@ def __init__(self,
710716 contents that will be stored and returned by sample_transition_batch.
711717 observation_dtype: np.dtype, type of the observations. Defaults to
712718 np.uint8 for Atari 2600.
719+ terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8 for
720+ Atari 2600.
713721 action_shape: tuple of ints, the shape for the action vector. Empty tuple
714722 means the action is a scalar.
715723 action_dtype: np.dtype, type of elements in the action.
@@ -745,6 +753,7 @@ def __init__(self,
745753 gamma ,
746754 max_sample_attempts ,
747755 observation_dtype = observation_dtype ,
756+ terminal_dtype = terminal_dtype ,
748757 extra_storage_types = extra_storage_types ,
749758 action_shape = action_shape ,
750759 action_dtype = action_dtype ,
@@ -765,7 +774,7 @@ def add(self, observation, action, reward, terminal, *args):
765774 observation: np.array with shape observation_shape.
766775 action: int, the action in the transition.
767776 reward: float, the reward received in the transition.
768- terminal: A uint8 acting as a boolean indicating whether the transition
777+ terminal: A value of terminal_dtype acting as a boolean indicating whether the transition
769778 was terminal (1) or not (0).
770779 *args: extra contents with shapes and dtypes according to
771780 extra_storage_types.
0 commit comments