-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenv_wrappers.py
More file actions
148 lines (124 loc) · 5.84 KB
/
env_wrappers.py
File metadata and controls
148 lines (124 loc) · 5.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import gymnasium as gym
import numpy as np
class Float64ToFloat32ActionWrapper(gym.ActionWrapper):
    """Expose a float32 action space on top of the wrapped environment.

    When the wrapped environment has a ``Box`` action space, this wrapper
    advertises the same bounds and shape with dtype float32. Every action
    passed through the wrapper is cast to float64 before it is handed to the
    wrapped environment. Non-``Box`` action spaces are not re-declared, but
    actions are still cast to float64.
    """

    def __init__(self, env):
        """Wrap *env* and re-declare a Box action space as float32.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)
        inner_space = self.env.action_space
        # Only Box spaces carry low/high/shape that can be re-declared.
        if isinstance(inner_space, gym.spaces.Box):
            self.action_space = gym.spaces.Box(
                low=inner_space.low,
                high=inner_space.high,
                shape=inner_space.shape,
                dtype=np.float32,
            )

    def action(self, action):
        """Cast an incoming action to float64 for the wrapped environment.

        Args:
            action: The action supplied by the agent. NumPy arrays and NumPy
                scalars are cast directly; any other array-like (list, tuple,
                Python number) is first converted to a NumPy array.

        Returns:
            A float64 copy of the action.
        """
        # Lists, tuples and Python scalars have no ``astype``; coerce them so
        # they do not fail with AttributeError.
        if not isinstance(action, (np.ndarray, np.generic)):
            action = np.asarray(action)
        return action.astype(np.float64)
class DiscretizeActionWrapper(gym.ActionWrapper):
    """
    Discretizes a continuous Box action space into a MultiDiscrete space.

    Each action dimension is split into evenly spaced values between its
    lower and upper bound (both included). The number of values can be set
    globally or per dimension. A dimension with a single bin always maps to
    its lower bound.
    """

    def __init__(self, env, n_bins_per_dim=2):
        """
        Initializes the wrapper.

        Args:
            env: The environment to wrap. Its action space must be a Box.
            n_bins_per_dim: An integer or a list/tuple/NumPy array specifying
                            the number of discrete bins for each action
                            dimension. If an integer, the same number of bins
                            is used for all dimensions.

        Raises:
            AssertionError: If the action space is not a Box, if the number
                of entries in n_bins_per_dim does not match the number of
                action dimensions, or if any bin count is below 1.
            ValueError: If n_bins_per_dim has an unsupported type, or if a
                dimension with more than one bin has a non-finite bound (the
                step between bins would be inf/NaN).
        """
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.Box), "This wrapper only works with Box action spaces."

        self.original_action_space = env.action_space
        self.low = self.original_action_space.low
        self.high = self.original_action_space.high
        action_dims = self.original_action_space.shape[0]

        # Normalise NumPy inputs to plain Python values so that np.int64
        # counts and ndarray bin specifications are accepted too.
        if isinstance(n_bins_per_dim, np.ndarray):
            n_bins_per_dim = n_bins_per_dim.tolist()
        elif isinstance(n_bins_per_dim, np.integer):
            n_bins_per_dim = int(n_bins_per_dim)

        if isinstance(n_bins_per_dim, int):
            self.n_bins = [n_bins_per_dim] * action_dims
        elif isinstance(n_bins_per_dim, (list, tuple)):
            assert len(n_bins_per_dim) == action_dims, \
                f"Length of n_bins_per_dim ({len(n_bins_per_dim)}) must match action dimensions ({action_dims})."
            self.n_bins = list(n_bins_per_dim)
        else:
            raise ValueError("n_bins_per_dim must be an integer or a list/tuple.")

        assert all(bins >= 1 for bins in self.n_bins), "Number of bins must be at least 1 for each dimension."

        self.action_space = gym.spaces.MultiDiscrete(self.n_bins)

        # Pre-calculate the spacing between neighbouring discrete values.
        self.steps = []
        for dim, (low, high, bins) in enumerate(zip(self.low, self.high, self.n_bins)):
            if bins > 1:
                step = (high - low) / (bins - 1)
                # An infinite bound gives an inf/NaN step, which would turn
                # every mapped action into NaN or inf; fail loudly instead.
                if not np.all(np.isfinite(step)):
                    raise ValueError(
                        f"Action dimension {dim} has non-finite bounds "
                        f"[{low}, {high}] and cannot be split into {bins} bins."
                    )
            else:
                # A single bin has no spacing; the action maps to the lower bound.
                step = 0
            self.steps.append(step)

    def action(self, action):
        """Maps the discrete action back to the continuous space.

        Args:
            action: A member of the MultiDiscrete action space; entry i is
                the bin index (0 to n_bins[i] - 1) for dimension i.

        Returns:
            An array with the original action space's dtype, clipped to the
            original bounds.

        Raises:
            AssertionError: If the action is not contained in the
                MultiDiscrete action space.
        """
        assert self.action_space.contains(action), f"Action {action} is invalid for space {self.action_space}"

        continuous_action = np.zeros_like(self.original_action_space.low, dtype=self.original_action_space.dtype)
        for i, (bin_index, step) in enumerate(zip(action, self.steps)):
            # Bin 0 is the lower bound; the last bin lands on the upper bound.
            continuous_action[i] = self.low[i] + bin_index * step

        # Clip to guard against floating point drift past the bounds.
        return np.clip(continuous_action, self.low, self.high)
class DiscreteToOneHotObservationWrapper(gym.ObservationWrapper):
    """
    Converts a Discrete observation space to a Box observation space using one-hot encoding.

    The encoded vector has length ``n``. If the wrapped space declares a
    ``start`` offset, the offset is subtracted first, so the smallest valid
    observation lands in slot 0.
    """

    def __init__(self, env):
        """
        Initializes the wrapper.

        Args:
            env: The environment to wrap. Its observation space must be Discrete.

        Raises:
            AssertionError: If the observation space is not Discrete.
        """
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Discrete), \
            "This wrapper only works with Discrete observation spaces."

        self.n = env.observation_space.n
        # Discrete spaces may begin at a non-zero value; spaces without a
        # ``start`` attribute are treated as starting at 0.
        self._start = getattr(env.observation_space, "start", 0)

        # Create a Box space with shape (n,) representing one-hot encoding
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=(self.n,),
            dtype=np.float32
        )

    def observation(self, observation):
        """Converts the discrete observation to one-hot encoding.

        Args:
            observation: A value from the wrapped Discrete space.

        Returns:
            A float32 vector of length ``n`` with a 1.0 in the slot for the
            observation and 0.0 elsewhere.

        Raises:
            IndexError: If the observation lies outside the wrapped space.
        """
        index = np.asarray(observation) - self._start
        # Reject out-of-range values explicitly: a negative index would
        # otherwise wrap around and mark the wrong slot.
        if np.any(index < 0) or np.any(index >= self.n):
            raise IndexError(
                f"Observation {observation} is outside the Discrete space "
                f"(start={self._start}, n={self.n})."
            )

        one_hot = np.zeros(self.n, dtype=np.float32)
        one_hot[index] = 1.0
        return one_hot
class MinAtarStateWrapper(gym.ObservationWrapper):
    """
    Observation wrapper for MinAtar environments.

    Casts each observation to float32 and reorders its axes from
    (H, W, C) to (C, H, W). No PyTorch dependency is involved.
    """

    def __init__(self, env):
        """
        Args:
            env: MinAtar environment to wrap
        """
        super().__init__(env)

        source_space = env.observation_space
        # Only a 3-D Box space is re-declared; anything else is left as-is.
        if not isinstance(source_space, gym.spaces.Box):
            return
        if len(source_space.shape) != 3:
            return

        height, width, channels = source_space.shape
        # The new space uses one scalar bound pair: the smallest lower bound
        # and the largest upper bound of the wrapped space.
        lower = float(source_space.low.min())
        upper = float(source_space.high.max())
        self.observation_space = gym.spaces.Box(
            low=lower,
            high=upper,
            shape=(channels, height, width),
            dtype=np.float32
        )

    def observation(self, observation):
        """
        Transform a single observation.

        Returns the observation cast to float32 with its axes reordered
        from (H, W, C) to (C, H, W).
        """
        return observation.astype(np.float32).transpose(2, 0, 1)