11"""Reward function for the PEBBLE training algorithm."""
22
33import enum
4- from typing import Dict , Optional , Tuple , Union
4+ from typing import Optional , Tuple
55
6+ import gym
67import numpy as np
78import torch as th
89
1213 ReplayBufferView ,
1314)
1415from imitation .rewards .reward_function import RewardFn
16+ from imitation .rewards .reward_nets import NormalizedRewardNet , RewardNet
1517from imitation .util import util
1618from imitation .util .networks import RunningNorm
1719
@@ -23,6 +25,92 @@ class PebbleRewardPhase(enum.Enum):
     POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward


+class InsufficientObservations(RuntimeError):
+    """Raised when the replay buffer holds too few observations for the entropy estimate."""
+
+
+class EntropyRewardNet(RewardNet):
+    """RewardNet computing a k-nearest-neighbor entropy estimate over replay buffer observations."""
+
+    def __init__(
+        self,
+        nearest_neighbor_k: int,
+        replay_buffer_view: ReplayBufferView,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        normalize_images: bool = True,
+    ):
+        """Initialize the RewardNet.
+
+        Args:
+            nearest_neighbor_k: the k used in the k-nearest-neighbor
+                entropy estimate
+            replay_buffer_view: view into the replay buffer whose observations
+                serve as the comparison set (can be replaced later via
+                set_replay_buffer())
+            observation_space: the observation space of the environment
+            action_space: the action space of the environment
+            normalize_images: whether to automatically normalize
+                image observations to [0, 1] (from 0 to 255). Defaults to True.
+        """
+        super().__init__(observation_space, action_space, normalize_images)
+        self.nearest_neighbor_k = nearest_neighbor_k
+        self._replay_buffer_view = replay_buffer_view
+
+    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
+        """Set the replay buffer; this must be called after unpickling.
+
+        See also __getstate__() / __setstate__().
+        """
+        assert self.observation_space == replay_buffer.observation_space
+        assert self.action_space == replay_buffer.action_space
+        self._replay_buffer_view = replay_buffer.buffer_view
+
+    def forward(
+        self,
+        state: th.Tensor,
+        action: th.Tensor,
+        next_state: th.Tensor,
+        done: th.Tensor,
+    ) -> th.Tensor:
+        assert (
+            self._replay_buffer_view is not None
+        ), "Missing replay buffer (possibly after unpickle)"
+
+        all_observations = self._replay_buffer_view.observations
+        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
+        all_observations = all_observations.reshape(
+            (-1,) + self.observation_space.shape
+        )
+
+        if all_observations.shape[0] < self.nearest_neighbor_k:
+            raise InsufficientObservations(
+                "Insufficient observations for entropy calculation"
+            )
+
+        return util.compute_state_entropy(
+            state, all_observations, self.nearest_neighbor_k
+        )
+
+    def preprocess(
+        self,
+        state: np.ndarray,
+        action: np.ndarray,
+        next_state: np.ndarray,
+        done: np.ndarray,
+    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
+        """Override default preprocessing to avoid the default one-hot encoding.
+
+        forward() only uses the state, so there is no need to convert the
+        other tensors.
+        """
+        state_th = util.safe_to_tensor(state).to(self.device)
+        action_th = next_state_th = done_th = th.empty(0)
+        return state_th, action_th, next_state_th, done_th
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_replay_buffer_view"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._replay_buffer_view = None
+
+
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """Reward function for implementation of the PEBBLE learning algorithm.

@@ -59,17 +147,27 @@ def __init__(
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k

-        self.entropy_stats = RunningNorm(1)
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

         # These two need to be set with set_replay_buffer():
-        self.replay_buffer_view: Optional[ReplayBufferView] = None
-        self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None
+        self._entropy_reward_net: Optional[EntropyRewardNet] = None
+        self._normalized_entropy_reward_net: Optional[RewardNet] = None

     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        self.replay_buffer_view = replay_buffer.buffer_view
-        self.obs_shape = replay_buffer.obs_shape
-
+        if self._normalized_entropy_reward_net is None:
+            self._entropy_reward_net = EntropyRewardNet(
+                nearest_neighbor_k=self.nearest_neighbor_k,
+                replay_buffer_view=replay_buffer.buffer_view,
+                observation_space=replay_buffer.observation_space,
+                action_space=replay_buffer.action_space,
+                normalize_images=False,
+            )
+            self._normalized_entropy_reward_net = NormalizedRewardNet(
+                self._entropy_reward_net, RunningNorm
+            )
+        else:
+            assert self._entropy_reward_net is not None
+            self._entropy_reward_net.set_replay_buffer(replay_buffer)

     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
@@ -88,35 +186,15 @@ def __call__(
         return self.learned_reward_fn(state, action, next_state, done)

     def _entropy_reward(self, state, action, next_state, done):
-        if self.replay_buffer_view is None:
+        if self._normalized_entropy_reward_net is None:
             raise ValueError(
                 "Replay buffer must be supplied before entropy reward can be used",
             )
-        all_observations = self.replay_buffer_view.observations
-        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
-        all_observations = all_observations.reshape((-1, *self.obs_shape))
-
-        if all_observations.shape[0] < self.nearest_neighbor_k:
+        try:
+            return self._normalized_entropy_reward_net.predict_processed(
+                state, action, next_state, done, update_stats=True
+            )
+        except InsufficientObservations:
             # not enough observations to compare to, fall back to the learned function;
             # (falling back to a constant may also be ok)
             return self.learned_reward_fn(state, action, next_state, done)
-        else:
-            # TODO #625: deal with the conversion back and forth between np and torch
-            entropies = util.compute_state_entropy(
-                th.tensor(state),
-                th.tensor(all_observations),
-                self.nearest_neighbor_k,
-            )
-
-            normalized_entropies = self.entropy_stats.forward(entropies)
-
-            return normalized_entropies.numpy()
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        del state["replay_buffer_view"]
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self.replay_buffer_view = None
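
Note: for intuition about what EntropyRewardNet.forward() delegates to, here is a minimal, self-contained sketch of the particle-based k-nearest-neighbor entropy estimate that util.compute_state_entropy is expected to provide. The function name knn_state_entropy and the exact distance/log formula below are illustrative assumptions, not the library's implementation.

import torch as th


def knn_state_entropy(states: th.Tensor, all_obs: th.Tensor, k: int) -> th.Tensor:
    """Illustrative particle-based entropy estimate (assumption, not imitation's code).

    The entropy of a state is approximated from the distance to its k-th
    nearest neighbor among the observations stored in the replay buffer.
    """
    # Pairwise Euclidean distances between the query batch and the buffer.
    dists = th.cdist(states.flatten(1).float(), all_obs.flatten(1).float())
    # Distance to the k-th nearest neighbor (kthvalue is 1-indexed).
    knn_dist = th.kthvalue(dists, k=k, dim=1).values
    return th.log(knn_dist + 1.0)


# Tiny usage example with random data standing in for replay buffer contents.
batch = th.randn(4, 3)
buffer_obs = th.randn(128, 3)
print(knn_state_entropy(batch, buffer_obs, k=5))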
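
The __getstate__ / __setstate__ pair on EntropyRewardNet follows a common pattern for objects that hold a live resource which should not travel through pickle: drop the reference when pickling and require the caller to restore it explicitly after unpickling (here via set_replay_buffer()). A standalone sketch of the idea; the class and method names below are hypothetical and not part of imitation.

import pickle


class BufferAwareNet:
    """Toy class holding a reference that must not travel through pickle."""

    def __init__(self, buffer, k: int = 5):
        self.k = k  # ordinary picklable attribute
        self._buffer = buffer  # live resource, e.g. a replay buffer view

    def set_buffer(self, buffer):
        """Restore the live reference; must be called after unpickling."""
        self._buffer = buffer

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["_buffer"]  # exclude the live reference from the pickle
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._buffer = None  # caller is expected to call set_buffer()


net = BufferAwareNet(buffer=object())
restored = pickle.loads(pickle.dumps(net))
assert restored.k == 5 and restored._buffer is None
restored.set_buffer(buffer=object())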