Skip to content

Commit 36bd1ef

Browse files
committed
Fix GymWalk env API
1 parent 06e25f1 commit 36bd1ef

14 files changed

+305
-179
lines changed

api/gym_walk_env/gym_walk.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
this is a copy verbatim of the code in that repository
33
44
"""
5-
import sys
5+
66
import numpy as np
77
from six import StringIO
8-
from string import ascii_uppercase
8+
99
from typing import Optional
1010

1111
import gymnasium as gym
12-
from gymnasium import spaces, utils
12+
from gymnasium import spaces
1313
from gymnasium.envs.toy_text.utils import categorical_sample
1414

1515
WEST, EAST = 0, 1
@@ -59,7 +59,7 @@ def __init__(self, n_states=7, p_stay=0.0, p_backward=0.5):
5959

6060
self.s = categorical_sample(self.isd, self.np_random)
6161

62-
def step(self, action: int):
62+
def step(self, action: int) -> tuple[int, float, bool, dict]:
6363
transitions = self.P[self.s][action]
6464
i = categorical_sample([t[0] for t in transitions], self.np_random)
6565
p, s, r, d = transitions[i]
@@ -75,14 +75,16 @@ def reset(self, *, seed: Optional[int] = None,
7575
self.lastaction = None
7676
return int(self.s)
7777

78-
def render(self, mode='human', close=False) -> None:
79-
outfile = StringIO() if mode == 'ansi' else sys.stdout
80-
desc = np.asarray(['[' + ascii_uppercase[:self.shape[1] - 2] + ']'], dtype='c').tolist()
81-
desc = [[c.decode('utf-8') for c in line] for line in desc]
82-
color = 'red' if self.s == 0 else 'green' if self.s == self.nS - 1 else 'yellow'
83-
desc[0][self.s] = utils.colorize(desc[0][self.s], color, highlight=True)
84-
outfile.write("\n")
85-
outfile.write("\n".join(''.join(line) for line in desc) + "\n")
86-
87-
if mode != 'human':
88-
return outfile
78+
# def render(self, mode='human', close=False) -> None:
79+
# outfile = StringIO() if mode == 'ansi' else sys.stdout
80+
# desc = np.asarray(['[' + ascii_uppercase[:self.shape[1] - 2] + ']'], dtype='c').tolist()
81+
# desc = [[c.decode('utf-8') for c in line] for line in desc]
82+
# color = 'red' if self.s == 0 else 'green' if self.s == self.nS - 1 else 'yellow'
83+
# desc[0][self.s] = utils.colorize(desc[0][self.s], color, highlight=True)
84+
# outfile.write("\n")
85+
# outfile.write("\n".join(''.join(line) for line in desc) + "\n")
86+
#
87+
# if mode != 'human':
88+
# return outfile
89+
def close(self) -> None:
90+
super().close()

0 commit comments

Comments
 (0)