 from imitation.util.networks import RunningNorm
 
 
-class PebbleRewardPhase(enum.Enum):
-    """States representing different behaviors for PebbleStateEntropyReward."""
-
-    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
-    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
-
-
 class InsufficientObservations(RuntimeError):
     pass
 
 
-class EntropyRewardNet(RewardNet):
+class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
     def __init__(
         self,
         nearest_neighbor_k: int,
-        replay_buffer_view: ReplayBufferView,
         observation_space: gym.Space,
         action_space: gym.Space,
         normalize_images: bool = True,
+        replay_buffer_view: Optional[ReplayBufferView] = None,
     ):
         """Initialize the RewardNet.
 
         Args:
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy()).
             observation_space: the observation space of the environment
             action_space: the action space of the environment
             normalize_images: whether to automatically normalize
                 image observations to [0, 1] (from 0 to 255). Defaults to True.
+            replay_buffer_view: Replay buffer view with observations to compare
+                against when computing entropy. If None is given, the buffer
+                needs to be set with on_replay_buffer_initialized() before
+                EntropyRewardNet can be used.
         """
         super().__init__(observation_space, action_space, normalize_images)
         self.nearest_neighbor_k = nearest_neighbor_k
         self._replay_buffer_view = replay_buffer_view
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
-        """This method needs to be called after unpickling.
+    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
+        """Sets the replay buffer.
 
-        See also __getstate__() / __setstate__()
+        This method needs to be called, e.g., after unpickling.
+        See also __getstate__() / __setstate__().
         """
         assert self.observation_space == replay_buffer.observation_space
         assert self.action_space == replay_buffer.action_space
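A minimal sketch of the pickling contract this docstring describes, assuming the `__setstate__` shown in the following hunk resets the buffer view to None on deserialization; the `entropy_reward_net` and `replay_buffer_wrapper` objects are illustrative assumptions, not part of this diff:

```python
import pickle

# Round-trip through pickle; __setstate__ resets _replay_buffer_view to None.
restored = pickle.loads(pickle.dumps(entropy_reward_net))

# The buffer view must be re-attached before entropy rewards can be
# computed again.
restored.on_replay_buffer_initialized(replay_buffer_wrapper)
```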
@@ -111,6 +111,13 @@ def __setstate__(self, state):
         self._replay_buffer_view = None
 
 
+class PebbleRewardPhase(enum.Enum):
+    """States representing different behaviors for PebbleStateEntropyReward."""
+
+    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
+
+
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """Reward function for implementation of the PEBBLE learning algorithm.
 
@@ -126,14 +133,15 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     reward is returned.
 
     The second phase requires that a buffer with observations to compare against is
-    supplied with set_replay_buffer() or on_replay_buffer_initialized().
-    To transition to the last phase, unsupervised_exploration_finish() needs
-    to be called.
+    supplied with on_replay_buffer_initialized(). To transition to the last phase,
+    unsupervised_exploration_finish() needs to be called.
     """
 
     def __init__(
         self,
         learned_reward_fn: RewardFn,
+        observation_space: gym.Space,
+        action_space: gym.Space,
         nearest_neighbor_k: int = 5,
     ):
         """Builds this class.
@@ -146,28 +154,20 @@ def __init__(
         """
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k
-
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
-        # These two need to be set with set_replay_buffer():
-        self._entropy_reward_net: Optional[EntropyRewardNet] = None
-        self._normalized_entropy_reward_net: Optional[RewardNet] = None
+        self._entropy_reward_net = EntropyRewardNet(
+            nearest_neighbor_k=self.nearest_neighbor_k,
+            observation_space=observation_space,
+            action_space=action_space,
+            normalize_images=False,
+        )
+        self._normalized_entropy_reward_net = NormalizedRewardNet(
+            self._entropy_reward_net, RunningNorm
+        )
 
     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        if self._normalized_entropy_reward_net is None:
-            self._entropy_reward_net = EntropyRewardNet(
-                nearest_neighbor_k=self.nearest_neighbor_k,
-                replay_buffer_view=replay_buffer.buffer_view,
-                observation_space=replay_buffer.observation_space,
-                action_space=replay_buffer.action_space,
-                normalize_images=False,
-            )
-            self._normalized_entropy_reward_net = NormalizedRewardNet(
-                self._entropy_reward_net, RunningNorm
-            )
-        else:
-            assert self._entropy_reward_net is not None
-            self._entropy_reward_net.set_replay_buffer(replay_buffer)
+        self._entropy_reward_net.on_replay_buffer_initialized(replay_buffer)
 
     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
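For reviewers, a minimal sketch of the intended lifecycle after this change, with spaces now supplied to the constructor so the entropy net can be built eagerly in `__init__`. The spaces, the zero-reward stub, and `replay_buffer_wrapper` are illustrative assumptions, not part of this diff:

```python
import gym.spaces
import numpy as np

obs_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))
act_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))

def learned_reward_fn(obs, acts, next_obs, dones):
    # Stand-in for the preference-learned reward used in the final phase.
    return np.zeros(len(obs))

pebble_reward = PebbleStateEntropyReward(
    learned_reward_fn=learned_reward_fn,
    observation_space=obs_space,
    action_space=act_space,
    nearest_neighbor_k=5,
)

# ReplayBufferRewardWrapper calls this back once its storage exists; with
# this change it simply forwards to the inner EntropyRewardNet.
pebble_reward.on_replay_buffer_initialized(replay_buffer_wrapper)  # wrapper assumed

# Phase 1: entropy-based reward during unsupervised exploration.
# ...collect experience...

# Transition: from here on, learned_reward_fn is returned.
pebble_reward.unsupervised_exploration_finish()
```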