     ReplayBufferView,
 )
 from imitation.rewards.reward_function import RewardFn
-from imitation.rewards.reward_nets import NormalizedRewardNet, RewardNet
+from imitation.rewards.reward_nets import RewardNet
 from imitation.util import util
-from imitation.util.networks import RunningNorm
-
-
-class PebbleRewardPhase(enum.Enum):
-    """States representing different behaviors for PebbleStateEntropyReward."""
-
-    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
-    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
 
 
 class InsufficientObservations(RuntimeError):
     pass
 
 
-class EntropyRewardNet(RewardNet):
+class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
     def __init__(
         self,
         nearest_neighbor_k: int,
-        replay_buffer_view: ReplayBufferView,
         observation_space: gym.Space,
         action_space: gym.Space,
         normalize_images: bool = True,
+        replay_buffer_view: Optional[ReplayBufferView] = None,
     ):
         """Initialize the RewardNet.
 
         Args:
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
             observation_space: the observation space of the environment
             action_space: the action space of the environment
             normalize_images: whether to automatically normalize
                 image observations to [0, 1] (from 0 to 255). Defaults to True.
+            replay_buffer_view: Replay buffer view with observations to compare
+                against when computing entropy. If None is given, the buffer needs to
+                be set with on_replay_buffer_initialized() before EntropyRewardNet can
+                be used
         """
         super().__init__(observation_space, action_space, normalize_images)
         self.nearest_neighbor_k = nearest_neighbor_k
         self._replay_buffer_view = replay_buffer_view
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
-        """This method needs to be called after unpickling.
+    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
+        """Sets replay buffer.
 
-        See also __getstate__() / __setstate__()
+        This method needs to be called, e.g., after unpickling.
+        See also __getstate__() / __setstate__().
         """
         assert self.observation_space == replay_buffer.observation_space
         assert self.action_space == replay_buffer.action_space
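Since `replay_buffer_view` is now optional, the net can be constructed before any replay buffer exists and wired up later through `on_replay_buffer_initialized()`. Below is a minimal sketch of that two-step setup, assuming the classes in this diff are importable; the observation/action spaces and the `replay_buffer_wrapper` variable are illustrative placeholders, not part of the library.

```python
import gym
import numpy as np

# Placeholder spaces; use the environment's real spaces in practice.
obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)

entropy_net = EntropyRewardNet(
    nearest_neighbor_k=5,
    observation_space=obs_space,
    action_space=act_space,
    normalize_images=False,
    replay_buffer_view=None,  # attached later via on_replay_buffer_initialized()
)

# Later, once the ReplayBufferRewardWrapper exists (e.g. after the off-policy
# algorithm has been set up), hand it to the net:
# entropy_net.on_replay_buffer_initialized(replay_buffer_wrapper)
```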
@@ -111,6 +110,13 @@ def __setstate__(self, state):
         self._replay_buffer_view = None
 
 
+class PebbleRewardPhase(enum.Enum):
+    """States representing different behaviors for PebbleStateEntropyReward."""
+
+    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
+
+
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """Reward function for implementation of the PEBBLE learning algorithm.
 
@@ -126,48 +132,30 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     reward is returned.
 
     The second phase requires that a buffer with observations to compare against is
-    supplied with set_replay_buffer() or on_replay_buffer_initialized().
-    To transition to the last phase, unsupervised_exploration_finish() needs
-    to be called.
+    supplied with on_replay_buffer_initialized(). To transition to the last phase,
+    unsupervised_exploration_finish() needs to be called.
     """
 
     def __init__(
         self,
+        entropy_reward_fn: RewardFn,
         learned_reward_fn: RewardFn,
-        nearest_neighbor_k: int = 5,
     ):
         """Builds this class.
 
         Args:
+            entropy_reward_fn: The entropy-based reward function used during
+                unsupervised exploration
            learned_reward_fn: The learned reward function used after unsupervised
                exploration is finished
-            nearest_neighbor_k: Parameter for entropy computation (see
-                compute_state_entropy())
         """
+        self.entropy_reward_fn = entropy_reward_fn
         self.learned_reward_fn = learned_reward_fn
-        self.nearest_neighbor_k = nearest_neighbor_k
-
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
-        # These two need to be set with set_replay_buffer():
-        self._entropy_reward_net: Optional[EntropyRewardNet] = None
-        self._normalized_entropy_reward_net: Optional[RewardNet] = None
-
     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        if self._normalized_entropy_reward_net is None:
-            self._entropy_reward_net = EntropyRewardNet(
-                nearest_neighbor_k=self.nearest_neighbor_k,
-                replay_buffer_view=replay_buffer.buffer_view,
-                observation_space=replay_buffer.observation_space,
-                action_space=replay_buffer.action_space,
-                normalize_images=False,
-            )
-            self._normalized_entropy_reward_net = NormalizedRewardNet(
-                self._entropy_reward_net, RunningNorm
-            )
-        else:
-            assert self._entropy_reward_net is not None
-            self._entropy_reward_net.set_replay_buffer(replay_buffer)
+        if isinstance(self.entropy_reward_fn, ReplayBufferAwareRewardFn):
+            self.entropy_reward_fn.on_replay_buffer_initialized(replay_buffer)
 
     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
@@ -181,20 +169,11 @@ def __call__(
         done: np.ndarray,
     ) -> np.ndarray:
         if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
-            return self._entropy_reward(state, action, next_state, done)
+            try:
+                return self.entropy_reward_fn(state, action, next_state, done)
+            except InsufficientObservations:
+                # not enough observations to compare to, fall back to the learned function;
+                # (falling back to a constant may also be ok)
+                return self.learned_reward_fn(state, action, next_state, done)
         else:
             return self.learned_reward_fn(state, action, next_state, done)
-
-    def _entropy_reward(self, state, action, next_state, done):
-        if self._normalized_entropy_reward_net is None:
-            raise ValueError(
-                "Replay buffer must be supplied before entropy reward can be used",
-            )
-        try:
-            return self._normalized_entropy_reward_net.predict_processed(
-                state, action, next_state, done, update_stats=True
-            )
-        except InsufficientObservations:
-            # not enough observations to compare to, fall back to the learned function;
-            # (falling back to a constant may also be ok)
-            return self.learned_reward_fn(state, action, next_state, done)
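Because PebbleStateEntropyReward no longer builds the normalized entropy net itself, the caller now assembles that pipeline and passes it in as `entropy_reward_fn`. The sketch below mirrors what the removed `on_replay_buffer_initialized()` body did (NormalizedRewardNet with RunningNorm, `predict_processed(..., update_stats=True)`); the `NormalizedEntropyRewardFn` adapter and the `learned_reward_fn` variable are illustrative only, and `entropy_net` is the instance from the earlier sketch.

```python
from imitation.rewards.reward_nets import NormalizedRewardNet
from imitation.util.networks import RunningNorm


class NormalizedEntropyRewardFn(ReplayBufferAwareRewardFn):
    """Illustrative adapter exposing a normalized EntropyRewardNet as a RewardFn."""

    def __init__(self, entropy_net: EntropyRewardNet):
        self._entropy_net = entropy_net
        self._normalized = NormalizedRewardNet(entropy_net, RunningNorm)

    def on_replay_buffer_initialized(self, replay_buffer):
        # Forward the buffer to the wrapped net (see EntropyRewardNet above).
        self._entropy_net.on_replay_buffer_initialized(replay_buffer)

    def __call__(self, state, action, next_state, done):
        # May raise InsufficientObservations early in training;
        # PebbleStateEntropyReward catches it and falls back to the learned reward.
        return self._normalized.predict_processed(
            state, action, next_state, done, update_stats=True
        )


pebble_reward_fn = PebbleStateEntropyReward(
    entropy_reward_fn=NormalizedEntropyRewardFn(entropy_net),
    learned_reward_fn=learned_reward_fn,  # e.g. a RewardFn from preference comparisons
)
```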