-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patharcdreaemr.py
More file actions
38 lines (32 loc) · 1.11 KB
/
arcdreaemr.py
File metadata and controls
38 lines (32 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np
from gym import ActionWrapper, ObservationWrapper, RewardWrapper, Wrapper
import gymnasium as gym
from gymnasium.spaces import Box, Discrete
class arcwrapper(gym.ObservationWrapper):
    """Add episode-boundary flags to a dict-observation environment.

    Every observation returned by ``reset`` and ``step`` is the wrapped
    environment's observation dict extended with three keys:

    * ``is_first``    -- True only for the observation returned by ``reset``.
    * ``is_last``     -- True when the episode ended for any reason
      (terminated or truncated).
    * ``is_terminal`` -- True only when the episode ended by termination.

    The wrapped environment must expose a mapping-like (Dict) observation
    space and return mapping observations, because both are unpacked with
    ``**``.
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` and extend its observation space with the flag keys."""
        super().__init__(env)
        self.observation_space = gym.spaces.Dict(
            {
                **env.observation_space,
                "is_first": gym.spaces.Box(0, 1, (1,), dtype=np.uint8),
                "is_last": gym.spaces.Box(0, 1, (1,), dtype=np.uint8),
                "is_terminal": gym.spaces.Box(0, 1, (1,), dtype=np.uint8),
            }
        )

    @staticmethod
    def _with_flags(obs, is_first, is_last, is_terminal):
        """Return a copy of ``obs`` with the three episode flags added."""
        # NOTE(review): the flags are emitted as plain bools although the
        # declared space is a uint8 Box of shape (1,); kept as in the original
        # since downstream consumers may rely on scalar flags -- confirm.
        return {
            **obs,
            "is_first": is_first,
            "is_last": is_last,
            "is_terminal": is_terminal,
        }

    def step(self, action):
        """Step the wrapped env and attach the episode flags.

        Returns the usual ``(obs, reward, terminated, truncated, info)`` tuple
        with ``obs`` extended by the flag keys.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        # An episode is over when it either terminates or is truncated (for
        # example by a time limit); only termination counts as terminal.
        episode_over = bool(terminated or truncated)
        obs = self._with_flags(
            obs,
            is_first=False,
            is_last=episode_over,
            is_terminal=bool(terminated),
        )
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding ``seed``/``options`` and any other kwargs.

        Returns ``(obs, info)`` with ``obs`` extended by the flag keys.
        """
        # Forward kwargs so that seeding and reset options reach the env.
        obs, info = self.env.reset(**kwargs)
        obs = self._with_flags(
            obs,
            is_first=True,
            is_last=False,
            is_terminal=False,
        )
        return obs, info