Skip to content

Commit 1d257b5

Browse files
authored
Add observation_space override and float32 cast to vector NormalizeOb… (#1528)
1 parent 4663578 commit 1d257b5

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

gymnasium/wrappers/vector/stateful_observation.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import gymnasium as gym
1313
from gymnasium.core import ObsType
1414
from gymnasium.logger import warn
15+
from gymnasium.spaces import Box
16+
from gymnasium.vector.utils import batch_space
1517
from gymnasium.vector.vector_env import (
1618
AutoresetMode,
1719
VectorEnv,
@@ -78,6 +80,15 @@ def __init__(self, env: VectorEnv, epsilon: float = 1e-8):
7880
else:
7981
assert self.env.metadata["autoreset_mode"] in {AutoresetMode.NEXT_STEP}
8082

83+
new_single_space = Box(
84+
low=-np.inf,
85+
high=np.inf,
86+
shape=self.single_observation_space.shape,
87+
dtype=np.float32,
88+
)
89+
self.single_observation_space = new_single_space
90+
self.observation_space = batch_space(new_single_space, self.num_envs)
91+
8192
self.obs_rms = RunningMeanStd(
8293
shape=self.single_observation_space.shape,
8394
dtype=self.single_observation_space.dtype,
@@ -120,6 +131,7 @@ def observations(self, observations: ObsType) -> ObsType:
120131
"""
121132
if self._update_running_mean:
122133
self.obs_rms.update(observations)
123-
return (observations - self.obs_rms.mean) / np.sqrt(
124-
self.obs_rms.var + self.epsilon
125-
)
134+
return (
135+
(observations - self.obs_rms.mean)
136+
/ np.sqrt(self.obs_rms.var + self.epsilon)
137+
).astype(np.float32)

tests/wrappers/vector/test_normalize_observation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,18 @@ def test_update_running_mean():
100100

101101
assert np.any(copied_rms_mean != env.obs_rms.mean)
102102
assert np.any(copied_rms_var != env.obs_rms.var)
103+
104+
105+
def test_observation_space_and_dtype():
106+
vec_env = SyncVectorEnv([create_env for _ in range(2)])
107+
vec_env = wrappers.vector.NormalizeObservation(vec_env)
108+
109+
assert vec_env.single_observation_space.dtype == np.float32
110+
assert np.all(vec_env.single_observation_space.low == -np.inf)
111+
assert np.all(vec_env.single_observation_space.high == np.inf)
112+
113+
obs, _ = vec_env.reset(seed=123)
114+
assert obs.dtype == np.float32
115+
116+
obs, *_ = vec_env.step(vec_env.action_space.sample())
117+
assert obs.dtype == np.float32

0 commit comments

Comments
 (0)