Commit 50577b0

Author: Jan Michelfeit (committed)
Commit message: #641 code review: replace RunningNorm with NormalizedRewardNet
1 parent: c80fb80

File tree: 4 files changed (+139, -60 lines)


src/imitation/algorithms/pebble/entropy_reward.py

Lines changed: 111 additions & 33 deletions
@@ -1,8 +1,9 @@
 """Reward function for the PEBBLE training algorithm."""

 import enum
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple

+import gym
 import numpy as np
 import torch as th

@@ -12,6 +13,7 @@
     ReplayBufferView,
 )
 from imitation.rewards.reward_function import RewardFn
+from imitation.rewards.reward_nets import NormalizedRewardNet, RewardNet
 from imitation.util import util
 from imitation.util.networks import RunningNorm

@@ -23,6 +25,92 @@ class PebbleRewardPhase(enum.Enum):
     POLICY_AND_REWARD_LEARNING = enum.auto()  # Learned reward


+class InsufficientObservations(RuntimeError):
+    pass
+
+
+class EntropyRewardNet(RewardNet):
+    def __init__(
+        self,
+        nearest_neighbor_k: int,
+        replay_buffer_view: ReplayBufferView,
+        observation_space: gym.Space,
+        action_space: gym.Space,
+        normalize_images: bool = True,
+    ):
+        """Initialize the RewardNet.
+
+        Args:
+            observation_space: the observation space of the environment
+            action_space: the action space of the environment
+            normalize_images: whether to automatically normalize
+                image observations to [0, 1] (from 0 to 255). Defaults to True.
+        """
+        super().__init__(observation_space, action_space, normalize_images)
+        self.nearest_neighbor_k = nearest_neighbor_k
+        self._replay_buffer_view = replay_buffer_view
+
+    def set_replay_buffer(self, replay_buffer: ReplayBufferRewardWrapper):
+        """This method needs to be called after unpickling.
+
+        See also __getstate__() / __setstate__()
+        """
+        assert self.observation_space == replay_buffer.observation_space
+        assert self.action_space == replay_buffer.action_space
+        self._replay_buffer_view = replay_buffer.buffer_view
+
+    def forward(
+        self,
+        state: th.Tensor,
+        action: th.Tensor,
+        next_state: th.Tensor,
+        done: th.Tensor,
+    ) -> th.Tensor:
+        assert (
+            self._replay_buffer_view is not None
+        ), "Missing replay buffer (possibly after unpickle)"
+
+        all_observations = self._replay_buffer_view.observations
+        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
+        all_observations = all_observations.reshape(
+            (-1,) + self.observation_space.shape
+        )
+
+        if all_observations.shape[0] < self.nearest_neighbor_k:
+            raise InsufficientObservations(
+                "Insufficient observations for entropy calculation"
+            )
+
+        return util.compute_state_entropy(
+            state, all_observations, self.nearest_neighbor_k
+        )
+
+    def preprocess(
+        self,
+        state: np.ndarray,
+        action: np.ndarray,
+        next_state: np.ndarray,
+        done: np.ndarray,
+    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
+        """Override default preprocessing to avoid the default one-hot encoding.
+
+        We also know forward() only works with state, so no need to convert
+        other tensors.
+        """
+        state_th = util.safe_to_tensor(state).to(self.device)
+        action_th = next_state_th = done_th = th.empty(0)
+        return state_th, action_th, next_state_th, done_th
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_replay_buffer_view"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._replay_buffer_view = None
+
+
 class PebbleStateEntropyReward(ReplayBufferAwareRewardFn):
     """Reward function for implementation of the PEBBLE learning algorithm.

@@ -59,17 +147,27 @@ def __init__(
         self.learned_reward_fn = learned_reward_fn
         self.nearest_neighbor_k = nearest_neighbor_k

-        self.entropy_stats = RunningNorm(1)
         self.state = PebbleRewardPhase.UNSUPERVISED_EXPLORATION

         # These two need to be set with set_replay_buffer():
-        self.replay_buffer_view: Optional[ReplayBufferView] = None
-        self.obs_shape: Union[Tuple[int, ...], Dict[str, Tuple[int, ...]], None] = None
+        self._entropy_reward_net: Optional[EntropyRewardNet] = None
+        self._normalized_entropy_reward_net: Optional[RewardNet] = None

     def on_replay_buffer_initialized(self, replay_buffer: ReplayBufferRewardWrapper):
-        self.replay_buffer_view = replay_buffer.buffer_view
-        self.obs_shape = replay_buffer.obs_shape
-
+        if self._normalized_entropy_reward_net is None:
+            self._entropy_reward_net = EntropyRewardNet(
+                nearest_neighbor_k=self.nearest_neighbor_k,
+                replay_buffer_view=replay_buffer.buffer_view,
+                observation_space=replay_buffer.observation_space,
+                action_space=replay_buffer.action_space,
+                normalize_images=False,
+            )
+            self._normalized_entropy_reward_net = NormalizedRewardNet(
+                self._entropy_reward_net, RunningNorm
+            )
+        else:
+            assert self._entropy_reward_net is not None
+            self._entropy_reward_net.set_replay_buffer(replay_buffer)

     def unsupervised_exploration_finish(self):
         assert self.state == PebbleRewardPhase.UNSUPERVISED_EXPLORATION
@@ -88,35 +186,15 @@ def __call__(
             return self.learned_reward_fn(state, action, next_state, done)

     def _entropy_reward(self, state, action, next_state, done):
-        if self.replay_buffer_view is None:
+        if self._normalized_entropy_reward_net is None:
             raise ValueError(
                 "Replay buffer must be supplied before entropy reward can be used",
             )
-        all_observations = self.replay_buffer_view.observations
-        # ReplayBuffer sampling flattens the venv dimension, let's adapt to that
-        all_observations = all_observations.reshape((-1, *self.obs_shape))
-
-        if all_observations.shape[0] < self.nearest_neighbor_k:
+        try:
+            return self._normalized_entropy_reward_net.predict_processed(
+                state, action, next_state, done, update_stats=True
+            )
+        except InsufficientObservations:
             # not enough observations to compare to, fall back to the learned function;
             # (falling back to a constant may also be ok)
             return self.learned_reward_fn(state, action, next_state, done)
-        else:
-            # TODO #625: deal with the conversion back and forth between np and torch
-            entropies = util.compute_state_entropy(
-                th.tensor(state),
-                th.tensor(all_observations),
-                self.nearest_neighbor_k,
-            )
-
-            normalized_entropies = self.entropy_stats.forward(entropies)
-
-            return normalized_entropies.numpy()
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        del state["replay_buffer_view"]
-        return state
-
-    def __setstate__(self, state):
-        self.__dict__.update(state)
-        self.replay_buffer_view = None
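
For context (not part of this diff): a minimal usage sketch of how EntropyRewardNet composes with NormalizedRewardNet, mirroring on_replay_buffer_initialized() and _entropy_reward() above and the fixtures in the tests below. The space, array shapes, k=4, and the zero fallback are illustrative assumptions only.

# Illustrative sketch only; mirrors the wiring in on_replay_buffer_initialized()
# and _entropy_reward() above. Shapes, k, and the fallback value are assumptions.
import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import (
    EntropyRewardNet,
    InsufficientObservations,
)
from imitation.policies.replay_buffer_wrapper import ReplayBufferView
from imitation.rewards.reward_nets import NormalizedRewardNet
from imitation.util.networks import RunningNorm

space = Box(-1, 1, shape=(1,))
# (buffer_size, n_envs, *obs_shape), as in the tests below
all_observations = np.random.rand(20, 1, *space.shape)
buffer_view = ReplayBufferView(all_observations, lambda: slice(None))

entropy_net = EntropyRewardNet(
    nearest_neighbor_k=4,
    replay_buffer_view=buffer_view,
    observation_space=space,
    action_space=space,
    normalize_images=False,
)
# NormalizedRewardNet with RunningNorm replaces the RunningNorm(1) field that
# PebbleStateEntropyReward kept before; it normalizes the raw entropy values
# with running statistics.
normalized_net = NormalizedRewardNet(entropy_net, RunningNorm)

state = np.random.rand(8, *space.shape)
placeholder = np.empty((8, *space.shape))  # action/next_state/done are ignored
try:
    rewards = normalized_net.predict_processed(
        state, placeholder, placeholder, placeholder, update_stats=True
    )
except InsufficientObservations:
    # Fewer than k observations in the buffer; PebbleStateEntropyReward falls
    # back to its learned reward function in this case (zeros here for brevity).
    rewards = np.zeros(len(state))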

src/imitation/scripts/config/train_preference_comparisons.py

Lines changed: 1 addition & 0 deletions
@@ -158,3 +158,4 @@ def fast():
     reward_trainer_kwargs = {
         "epochs": 1,
     }
+    locals()  # quieten flake8

src/imitation/util/util.py

Lines changed: 4 additions & 1 deletion
@@ -384,12 +384,15 @@ def compute_state_entropy(
     for idx in range(len(all_obs) // batch_size + 1):
         start = idx * batch_size
         end = (idx + 1) * batch_size
+        all_obs_batch = all_obs[start:end]
         distances_tensor = th.linalg.vector_norm(
-            obs[:, None] - all_obs[None, start:end],
+            obs[:, None] - all_obs_batch[None, :],
             dim=non_batch_dimensions,
             ord=2,
         )
+        assert distances_tensor.shape == (obs.shape[0], all_obs_batch.shape[0])
         dists.append(distances_tensor)
     all_dists = th.cat(dists, dim=1)
     knn_dists = th.kthvalue(all_dists, k=k + 1, dim=1).values
     return knn_dists
+
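
The new assert documents the per-batch shape contract. Below is a condensed, non-batched sketch of the distance / k-nearest-neighbor computation that compute_state_entropy performs (not part of this diff; it assumes flat (batch, obs_dim) tensors small enough that no batching is needed).

import torch as th

obs = th.rand(8, 3)        # states to score: (batch, obs_dim)
all_obs = th.rand(100, 3)  # flattened replay buffer: (buffer, obs_dim)
k = 4

# Pairwise L2 distances between every state and every buffer observation.
dists = th.linalg.vector_norm(obs[:, None] - all_obs[None, :], ord=2, dim=-1)
assert dists.shape == (obs.shape[0], all_obs.shape[0])  # shape the new assert checks per batch

# compute_state_entropy takes the (k+1)-th smallest distance per state,
# matching th.kthvalue(all_dists, k=k + 1, dim=1) above.
knn_dists = th.kthvalue(dists, k=k + 1, dim=1).values
assert knn_dists.shape == (obs.shape[0],)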

tests/algorithms/pebble/test_entropy_reward.py

Lines changed: 23 additions & 26 deletions
@@ -5,15 +5,14 @@

 import numpy as np
 import torch as th
-from gym.spaces import Discrete
-
+from gym.spaces import Discrete, Box
+from gym.spaces.space import Space
 from imitation.algorithms.pebble.entropy_reward import PebbleStateEntropyReward
 from imitation.policies.replay_buffer_wrapper import ReplayBufferView
 from imitation.util import util

-SPACE = Discrete(4)
-OBS_SHAPE = (1,)
-PLACEHOLDER = np.empty(OBS_SHAPE)
+SPACE = Box(-1, 1, shape=(1,))
+PLACEHOLDER = np.empty(SPACE.shape)

 BUFFER_SIZE = 20
 K = 4
@@ -22,30 +21,27 @@


 def test_pebble_entropy_reward_returns_entropy_for_pretraining(rng):
-    all_observations = rng.random((BUFFER_SIZE, VENVS, *OBS_SHAPE))
+    all_observations = rng.random((BUFFER_SIZE, VENVS) + SPACE.shape)

     reward_fn = PebbleStateEntropyReward(Mock(), K)
     reward_fn.on_replay_buffer_initialized(
         replay_buffer_mock(
             ReplayBufferView(all_observations, lambda: slice(None)),
-            OBS_SHAPE,
+            SPACE,
         )
     )

     # Act
-    observations = th.rand((BATCH_SIZE, *OBS_SHAPE))
+    observations = th.rand((BATCH_SIZE, *SPACE.shape))
     reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

     # Assert
     expected = util.compute_state_entropy(
         observations,
-        all_observations.reshape(-1, *OBS_SHAPE),
+        all_observations.reshape(-1, *SPACE.shape),
         K,
     )
-    expected_normalized = reward_fn.entropy_stats.normalize(
-        th.as_tensor(expected),
-    ).numpy()
-    np.testing.assert_allclose(reward, expected_normalized)
+    np.testing.assert_allclose(reward, expected, rtol=0.005, atol=0.005)


 def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
@@ -55,11 +51,11 @@ def test_pebble_entropy_reward_returns_normalized_values_for_pretraining():
         m.side_effect = lambda obs, all_obs, k: obs

         reward_fn = PebbleStateEntropyReward(Mock(), K)
-        all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
+        all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape))
         reward_fn.on_replay_buffer_initialized(
             replay_buffer_mock(
                 ReplayBufferView(all_observations, lambda: slice(None)),
-                OBS_SHAPE,
+                SPACE,
             )
         )

@@ -97,7 +93,7 @@ def test_pebble_entropy_reward_function_returns_learned_reward_after_pre_trainin
     reward_fn.unsupervised_exploration_finish()

     # Act
-    observations = np.ones((BATCH_SIZE, *OBS_SHAPE))
+    observations = np.ones((BATCH_SIZE, *SPACE.shape))
     reward = reward_fn(observations, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

     # Assert
@@ -111,23 +107,23 @@


 def test_pebble_entropy_reward_can_pickle():
-    all_observations = np.empty((BUFFER_SIZE, VENVS, *OBS_SHAPE))
+    all_observations = np.empty((BUFFER_SIZE, VENVS, *SPACE.shape))
     replay_buffer = ReplayBufferView(all_observations, lambda: slice(None))

-    obs1 = np.random.rand(VENVS, *OBS_SHAPE)
+    obs1 = np.random.rand(VENVS, *SPACE.shape)
     reward_fn = PebbleStateEntropyReward(reward_fn_stub, K)
-    reward_fn.on_replay_buffer_initialized(replay_buffer_mock(replay_buffer, OBS_SHAPE))
+    reward_fn.on_replay_buffer_initialized(replay_buffer_mock(replay_buffer, SPACE))
     reward_fn(obs1, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)

     # Act
     pickled = pickle.dumps(reward_fn)
     reward_fn_deserialized = pickle.loads(pickled)
     reward_fn_deserialized.on_replay_buffer_initialized(
-        replay_buffer_mock(replay_buffer, OBS_SHAPE)
+        replay_buffer_mock(replay_buffer, SPACE)
     )

     # Assert
-    obs2 = np.random.rand(VENVS, *OBS_SHAPE)
+    obs2 = np.random.rand(VENVS, *SPACE.shape)
     expected_result = reward_fn(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
     actual_result = reward_fn_deserialized(obs2, PLACEHOLDER, PLACEHOLDER, PLACEHOLDER)
     np.testing.assert_allclose(actual_result, expected_result)
@@ -137,8 +133,9 @@ def reward_fn_stub(state, action, next_state, done):
     return state


-def replay_buffer_mock(buffer_view: ReplayBufferView, obs_shape: tuple) -> Mock:
-    replay_buffer_mock = Mock()
-    replay_buffer_mock.buffer_view = buffer_view
-    replay_buffer_mock.obs_shape = obs_shape
-    return replay_buffer_mock
+def replay_buffer_mock(buffer_view: ReplayBufferView, obs_space: Space) -> Mock:
+    mock = Mock()
+    mock.buffer_view = buffer_view
+    mock.observation_space = obs_space
+    mock.action_space = SPACE
+    return mock
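
As a complement to test_pebble_entropy_reward_can_pickle, a brief sketch (not part of the diff) of the pickling contract at the EntropyRewardNet level: __getstate__ drops the buffer view, so set_replay_buffer() must be called on the unpickled net before forward() can run. The Mock stands in for a ReplayBufferRewardWrapper, just as replay_buffer_mock does above; values and shapes are illustrative assumptions.

import pickle
from unittest.mock import Mock

import numpy as np
from gym.spaces import Box

from imitation.algorithms.pebble.entropy_reward import EntropyRewardNet
from imitation.policies.replay_buffer_wrapper import ReplayBufferView

space = Box(-1, 1, shape=(1,))
observations = np.random.rand(20, 1, *space.shape)
buffer_view = ReplayBufferView(observations, lambda: slice(None))

net = EntropyRewardNet(
    nearest_neighbor_k=4,
    replay_buffer_view=buffer_view,
    observation_space=space,
    action_space=space,
    normalize_images=False,
)

# __getstate__ removes _replay_buffer_view, so the buffer reference is not
# dragged into the pickle; __setstate__ restores the attribute as None.
restored = pickle.loads(pickle.dumps(net))

# Reattach a buffer before calling forward(). set_replay_buffer() expects an
# object exposing buffer_view, observation_space, and action_space
# (a ReplayBufferRewardWrapper in real code; a Mock stands in here).
buffer_stub = Mock()
buffer_stub.buffer_view = buffer_view
buffer_stub.observation_space = space
buffer_stub.action_space = space
restored.set_replay_buffer(buffer_stub)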
