Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stable_baselines3/common/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def evaluate_policy(
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
-observations, rewards, dones, infos = env.step(actions)
+new_observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
Expand All @@ -100,6 +100,7 @@ def evaluate_policy(

if callback is not None:
callback(locals(), globals())
+observations = new_observations

if dones[i]:
if is_monitor_wrapped:
Expand Down