
Commit 531b353

Author: Jan Michelfeit (committed)
#641 code review: refactor PebbleStateEntropyReward so that inner RewardNets can be injected from the outside
1 parent 50577b0 commit 531b353

File tree

6 files changed: +241 -143 lines changed


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 34 additions & 55 deletions
@@ -13,47 +13,46 @@
     ReplayBufferView,
 )
 from imitation.rewards.reward_function import RewardFn
-from imitation.rewards.reward_nets import NormalizedRewardNet, RewardNet
+from imitation.rewards.reward_nets import RewardNet
 from imitation.util import util
-from imitation.util.networks import RunningNorm
-
-
-class PebbleRewardPhase(enum.Enum):
-    """States representing different behaviors for PebbleStateEntropyReward."""
-
-    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
-    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
 
 
 class InsufficientObservations(RuntimeError):
     pass
 
 
-class EntropyRewardNet(RewardNet):
+class EntropyRewardNet(RewardNet, ReplayBufferAwareRewardFn):
     def __init__(
         self,
         nearest_neighbor_k: int,
-        replay_buffer_view: ReplayBufferView,
         observation_space: gym.Space,
         action_space: gym.Space,
         normalize_images: bool = True,
+        replay_buffer_view: Optional[ReplayBufferView] = None,
     ):
         """Initialize the RewardNet.
 
         Args:
+            nearest_neighbor_k: Parameter for entropy computation (see
+                compute_state_entropy())
             observation_space: the observation space of the environment
             action_space: the action space of the environment
             normalize_images: whether to automatically normalize
                 image observations to [0, 1] (from 0 to 255). Defaults to True.
+            replay_buffer_view: Replay buffer view with observations to compare
+                against when computing entropy. If None is given, the buffer needs to
+                be set with on_replay_buffer_initialized() before EntropyRewardNet can
+                be used
         """
         super().__init__(observation_space, action_space, normalize_images)
         self.nearest_neighbor_k = nearest_neighbor_k
         self._replay_buffer_view = replay_buffer_view
 
-    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
-        """This method needs to be called after unpickling.
+    def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
+        """Sets replay buffer.
 
-        See also __getstate__() / __setstate__()
+        This method needs to be called, e.g., after unpickling.
+        See also __getstate__() / __setstate__().
         """
         assert self.observation_space == replay_buffer.observation_space
         assert self.action_space == replay_buffer.action_space
@@ -111,6 +110,13 @@ def __setstate__(self, state):
         self._replay_buffer_view = None
 
 
+class PebbleRewardPhase(enum.Enum):
+    """States representing different behaviors for PebbleStateEntropyReward."""
+
+    UNSUPERVISED_EXPLORATION = enum.auto()  # Entropy based reward
+    POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward
+
+
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """Reward function for implementation of the PEBBLE learning algorithm.
 
@@ -126,48 +132,30 @@ class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     reward is returned.
 
     The second phase requires that a buffer with observations to compare against is
-    supplied with set_replay_buffer() or on_replay_buffer_initialized().
-    To transition to the last phase, unsupervised_exploration_finish() needs
-    to be called.
+    supplied with on_replay_buffer_initialized(). To transition to the last phase,
+    unsupervised_exploration_finish() needs to be called.
     """
 
     def __init__(
         self,
+        entropy_reward_fn: RewardFn,
         learned_reward_fn: RewardFn,
-        nearest_neighbor_k: int = 5,
     ):
         """Builds this class.
 
         Args:
+            entropy_reward_fn: The entropy-based reward function used during
+                unsupervised exploration
             learned_reward_fn: The learned reward function used after unsupervised
                 exploration is finished
-            nearest_neighbor_k: Parameter for entropy computation (see
-                compute_state_entropy())
         """
+        self.entropy_reward_fn = entropy_reward_fn
         self.learned_reward_fn = learned_reward_fn
-        self.nearest_neighbor_k = nearest_neighbor_k
-
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION
 
-        # These two need to be set with set_replay_buffer():
-        self._entropy_reward_net: Optional[EntropyRewardNet] = None
-        self._normalized_entropy_reward_net: Optional[RewardNet] = None
-
     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        if self._normalized_entropy_reward_net is None:
-            self._entropy_reward_net = EntropyRewardNet(
-                nearest_neighbor_k=self.nearest_neighbor_k,
-                replay_buffer_view=replay_buffer.buffer_view,
-                observation_space=replay_buffer.observation_space,
-                action_space=replay_buffer.action_space,
-                normalize_images=False,
-            )
-            self._normalized_entropy_reward_net = NormalizedRewardNet(
-                self._entropy_reward_net, RunningNorm
-            )
-        else:
-            assert self._entropy_reward_net is not None
-            self._entropy_reward_net.set_replay_buffer(replay_buffer)
+        if isinstance(self.entropy_reward_fn, ReplayBufferAwareRewardFn):
+            self.entropy_reward_fn.on_replay_buffer_initialized(replay_buffer)
 
     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
@@ -181,20 +169,11 @@ def __call__(
         done: np.ndarray,
     ) -> np.ndarray:
         if self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION:
-            return self._entropy_reward(state, action, next_state, done)
+            try:
+                return self.entropy_reward_fn(state, action, next_state, done)
+            except InsufficientObservations:
+                # not enough observations to compare to, fall back to the learned function;
+                # (falling back to a constant may also be ok)
+                return self.learned_reward_fn(state, action, next_state, done)
         else:
             return self.learned_reward_fn(state, action, next_state, done)
-
-    def _entropy_reward(self, state, action, next_state, done):
-        if self._normalized_entropy_reward_net is None:
-            raise ValueError(
-                "Replay buffer must be supplied before entropy reward can be used",
-            )
-        try:
-            return self._normalized_entropy_reward_net.predict_processed(
-                state, action, next_state, done, update_stats=True
-            )
-        except InsufficientObservations:
-            # not enough observations to compare to, fall back to the learned function;
-            # (falling back to a constant may also be ok)
-            return self.learned_reward_fn(state, action, next_state, done)
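
For orientation, a minimal usage sketch of the refactored PebbleStateEntropyReward (the stand-in reward functions below are hypothetical, not part of this commit). During unsupervised exploration the injected entropy_reward_fn is called, with a fallback to learned_reward_fn whenever InsufficientObservations is raised; after unsupervised_exploration_finish() the learned reward is always used.

# Hypothetical sketch (not part of this commit): wiring the refactored class
# with plain callables that follow the RewardFn signature.
import numpy as np

from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward


def fake_entropy_reward(state, action, next_state, done) -> np.ndarray:
    # Stand-in for an entropy-based reward; a real one may raise
    # InsufficientObservations, which triggers the fallback.
    return np.zeros(len(state))


def fake_learned_reward(state, action, next_state, done) -> np.ndarray:
    # Stand-in for the learned reward model.
    return np.ones(len(state))


reward_fn = PebbleStateEntropyReward(fake_entropy_reward, fake_learned_reward)

obs = np.zeros((8, 3))
acts = np.zeros((8, 1))
dones = np.zeros(8, dtype=bool)

# Unsupervised exploration phase: entropy reward (with learned-reward fallback).
r_explore = reward_fn(obs, acts, obs, dones)

# Switch phases: from now on the learned reward is always used.
reward_fn.unsupervised_exploration_finish()
r_learned = reward_fn(obs, acts, obs, dones)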

src/imitation/scripts/train_preference_comparisons.py

Lines changed: 45 additions & 3 deletions
@@ -13,17 +13,26 @@
 from stable_baselines3.common import base_class, type_aliases, vec_env
 
 from imitation.algorithms import preference_comparisons
-from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
+from imitation.algorithms.pebble.entropy_reward import (
+    EntropyRewardNet,
+    PebbleStateEntropyReward,
+)
 from imitation.data import types
 from imitation.policies import serialize
+from imitation.policies.replay_buffer_wrapper import (
+    ReplayBufferAwareRewardFn,
+    ReplayBufferRewardWrapper,
+)
 from imitation.rewards import reward_function, reward_nets
+from imitation.rewards.reward_nets import NormalizedRewardNet
 from imitation.scripts.common import common, reward
 from imitation.scripts.common import rl as rl_common
 from imitation.scripts.common import train
 from imitation.scripts.config.train_preference_comparisons import (
     train_preference_comparisons_ex,
 )
 from imitation.util import logger as imit_logger
+from imitation.util.networks import RunningNorm
 
 
 def save_model(
@@ -71,14 +80,47 @@ def make_reward_function(
         reward_net.predict_processed,
         update_stats=False,
     )
+    observation_space = reward_net.observation_space
+    action_space = reward_net.action_space
     if pebble_enabled:
-        relabel_reward_fn = PebbleStateEntropyReward(
-            relabel_reward_fn,  # type: ignore[assignment]
+        relabel_reward_fn = create_pebble_reward_fn(
+            relabel_reward_fn,
             pebble_nearest_neighbor_k,
+            action_space,
+            observation_space,
         )
     return relabel_reward_fn
 
 
+def create_pebble_reward_fn(
+    relabel_reward_fn, pebble_nearest_neighbor_k, action_space, observation_space
+):
+    entropy_reward_net = EntropyRewardNet(
+        nearest_neighbor_k=pebble_nearest_neighbor_k,
+        observation_space=observation_space,
+        action_space=action_space,
+        normalize_images=False,
+    )
+    normalized_entropy_reward_net = NormalizedRewardNet(entropy_reward_net, RunningNorm)
+
+    class EntropyRewardFn(ReplayBufferAwareRewardFn):
+        """Adapter for entropy reward adding on_replay_buffer_initialized() hook."""
+
+        def __call__(self, *args, **kwargs) -> np.ndarray:
+            kwargs["update_stats"] = True
+            return normalized_entropy_reward_net.predict_processed(*args, **kwargs)
+
+        def on_replay_buffer_initialized(
+            self, replay_buffer: ReplayBufferRewardWrapper
+        ):
+            entropy_reward_net.on_replay_buffer_initialized(replay_buffer)
+
+    return PebbleStateEntropyReward(
+        EntropyRewardFn(),
+        relabel_reward_fn,  # type: ignore[assignment]
+    )
+
+
 @train_preference_comparisons_ex.capture
 def make_agent_trajectory_generator(
     venv: vec_env.VecEnv,
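
As a rough illustration (hypothetical spaces and a stand-in learned reward, not from this commit), the new helper can also be called directly; the inner EntropyRewardNet is now constructed here and injected into PebbleStateEntropyReward rather than built inside it.

# Hypothetical sketch: constructing the PEBBLE reward function outside the script.
import gym
import numpy as np

from imitation.scripts.train_preference_comparisons import create_pebble_reward_fn

observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(3,))
action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))


def learned_reward_fn(state, action, next_state, done) -> np.ndarray:
    # Stand-in for reward_net.predict_processed wrapped as a RewardFn.
    return np.zeros(len(state))


pebble_reward_fn = create_pebble_reward_fn(
    learned_reward_fn,
    5,  # pebble_nearest_neighbor_k
    action_space,
    observation_space,
)
# The entropy part still needs a replay buffer before it can be queried:
# pebble_reward_fn.on_replay_buffer_initialized(replay_buffer_wrapper)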

src/imitation/util/util.py

Lines changed: 0 additions & 1 deletion
@@ -395,4 +395,3 @@ def compute_state_entropy(
     all_dists = th.cat(dists, dim=1)
     knn_dists = th.kthvalue(all_dists, k=k + 1, dim=1).values
     return knn_dists
-
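
For reference, a minimal self-contained sketch of the k-nearest-neighbor distance step visible in this hunk (assumed shapes, and th.cdist in place of the function's batched distance loop; not the full compute_state_entropy implementation):

import torch as th

all_obs = th.randn(100, 3)   # e.g. observations from a replay buffer view
obs = all_obs[:4]            # query observations, contained in all_obs here

# Pairwise Euclidean distances, shape (4, 100).
all_dists = th.cdist(obs, all_obs)

k = 5
# k + 1 skips the zero distance of each observation to itself.
knn_dists = th.kthvalue(all_dists, k=k + 1, dim=1).values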
