diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a53fd999..113ce3f9 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -2,6 +2,7 @@ Changelog ========== +- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit) Release 1.8.0 (2023-04-07) -------------------------- @@ -418,4 +419,4 @@ Contributors: ------------- @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec -@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher +@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @npit diff --git a/sb3_contrib/common/maskable/evaluation.py b/sb3_contrib/common/maskable/evaluation.py index 8a9fba83..0330c28b 100644 --- a/sb3_contrib/common/maskable/evaluation.py +++ b/sb3_contrib/common/maskable/evaluation.py @@ -103,7 +103,7 @@ def evaluate_policy( actions, states = model.predict( observations, state=states, episode_start=episode_starts, deterministic=deterministic ) - observations, rewards, dones, infos = env.step(actions) + new_observations, rewards, dones, infos = env.step(actions) current_rewards += rewards current_lengths += 1 for i in range(n_envs): @@ -116,6 +116,7 @@ def evaluate_policy( if callback is not None: callback(locals(), globals()) + observations = new_observations if dones[i]: if is_monitor_wrapped: