|
| 1 | +import copy |
| 2 | +from typing import Dict, Optional, Tuple |
| 3 | + |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from examples.agentscope_frozenlake.utils import generate_random_map, get_goal_position |
| 7 | +from trinity.utils.log import get_logger |
| 8 | + |
try:
    from gymnasium.envs.toy_text.frozen_lake import FrozenLakeEnv as GymFrozenLakeEnv
except ImportError:
    # Keep this module importable without gymnasium; FrozenLakeEnv.__init__
    # raises a descriptive ImportError when the dependency is missing.
    GymFrozenLakeEnv = object


class FrozenLakeEnv(GymFrozenLakeEnv):
    """Text-oriented wrapper around gymnasium's ``FrozenLakeEnv``.

    Actions are given as strings (see ``ACTION_LOOKUP``) and translated to the
    integer actions passed to the gymnasium environment. Observations are
    produced by :meth:`render`, by default as a tab-separated text grid.
    """

    # Map gym tile bytes to integer state codes.
    MAP_LOOKUP = {
        b"P": 0,
        b"F": 1,
        b"H": 2,
        b"G": 3,
    }

    # Define rules to transform integer state codes to rendered text.
    GRID_LOOKUP = {
        0: " P \t",  # player
        1: " _ \t",  # frozen
        2: " O \t",  # hole
        3: " G \t",  # goal
        4: " X \t",  # player fall into hole
        5: " √ \t",  # player on goal
    }

    # Textual action -> custom action id. "still" shares id 0 with
    # INVALID_ACTION, so step() treats it as a no-op.
    ACTION_LOOKUP = {
        "still": 0,
        "left": 1,
        "down": 2,
        "right": 3,
        "up": 4,
    }

    INVALID_ACTION = 0
    # Not referenced inside this class.
    PENALTY_FOR_INVALID = -1

    def __init__(
        self,
        max_steps: int = 8,
        desc: Optional[str] = None,
        is_slippery: bool = False,
        size: int = 8,
        p: float = 0.8,
        seed: int = 42,
    ):
        """Build the map and initialise the underlying gymnasium environment.

        Args:
            max_steps: Step budget forwarded to ``generate_random_map``; a
                falsy value falls back to 8.
            desc: Optional explicit map, converted with
                ``np.asarray(..., dtype="c")``. When ``None`` a random map is
                generated from ``size``, ``p``, ``seed`` and ``max_steps``.
            is_slippery: Forwarded to the gymnasium environment.
            size: Forwarded to ``generate_random_map``.
            p: Forwarded to ``generate_random_map``.
            seed: Seed for map generation; also used by :meth:`reset`.

        Raises:
            ImportError: If gymnasium is not installed.
        """
        self.logger = get_logger()
        self.max_steps = max_steps or 8
        self.desc = desc
        self.is_slippery = is_slippery
        self.size = size
        self.p = p
        self.seed = seed
        try:
            import gymnasium as gym
            from gymnasium.envs.toy_text.frozen_lake import (
                FrozenLakeEnv as GymFrozenLakeEnv,
            )
        except ImportError as e:
            error_message = (
                f"Gymnasium is not installed. Please install gymnasium first before "
                f"running the frozen_lake workflow. Error: {str(e)}"
            )
            self.logger.error(error_message)
            # Chain the original failure so the root cause stays in the traceback.
            raise ImportError(error_message) from e

        if self.desc is None:
            random_map, goal_position = generate_random_map(
                size=self.size, p=self.p, seed=self.seed, max_steps=self.max_steps
            )
        else:
            random_map = np.asarray(copy.deepcopy(self.desc), dtype="c")
            goal_position = get_goal_position(random_map)

        self.goal_position = goal_position

        GymFrozenLakeEnv.__init__(self, desc=random_map[:], is_slippery=self.is_slippery)
        # Custom action ids are 1..4; id 0 is reserved for invalid / "still".
        self.action_space = gym.spaces.Discrete(4, start=1)

        # Remember construction arguments so reset() can reuse them as defaults.
        self.map_kwargs = {
            "size": size,
            "p": p,
        }
        self.env_kwargs = {
            "is_slippery": is_slippery,
            "desc": copy.deepcopy(desc),
            "seed": seed,
        }

        # Custom action id -> action index passed to the gymnasium step().
        self.action_map = {
            1: 0,  # left
            2: 1,  # down
            3: 2,  # right
            4: 3,  # up
        }

    def _get_player_position(self) -> Tuple[int, int]:
        """Return the player's ``(row, col)`` derived from the flat state ``self.s``."""
        return (self.s // self.ncol, self.s % self.ncol)  # (row, col)

    def step(self, action: str) -> Tuple[str, float, bool, Dict]:
        """Execute a step in the environment.

        Maps custom action to gymnasium FrozenLakeEnv action and takes the step.
        Checks if the action is effective (whether player moves in the env).

        Args:
            action: The action to take, matched case-insensitively against
                ``ACTION_LOOKUP``. Unknown names, ``"still"`` and non-string
                values are treated as invalid and leave the state untouched.

        Returns:
            Tuple of (observation, reward, done, info), where ``info`` holds
            the single key ``"action_is_effective"``.
        """
        if self.success():
            # Already on the goal: report success again without stepping.
            return self.render(), 1, True, {"action_is_effective": False}

        # A non-string action (e.g. None) must not crash on ``.lower()``.
        if isinstance(action, str):
            action_id: int = self.ACTION_LOOKUP.get(action.lower(), self.INVALID_ACTION)
        else:
            action_id = self.INVALID_ACTION

        if action_id == self.INVALID_ACTION or action_id not in self.action_map:
            return self.render(), 0, False, {"action_is_effective": False}

        prev_player_position = int(self.s)

        player_pos, reward, done, _, _ = GymFrozenLakeEnv.step(self, self.action_map[action_id])

        obs = self.render()
        return obs, reward, done, {"action_is_effective": prev_player_position != int(player_pos)}

    def _render_state(self):
        """Return the map as an integer array with the player's cell marked.

        Codes follow ``MAP_LOOKUP``; the player's cell is 0, or 4 / 5 when the
        player stands on a hole / the goal.
        """
        position = self._get_player_position()
        room_state = copy.deepcopy(self.desc)

        # The start tile 'S' is shown as ordinary frozen ground.
        room_state[np.where(room_state == b"S")] = b"F"
        room_state[position] = b"P"

        # transform 'P', 'F', 'H', 'G' to numpy integer array
        room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state)

        # add player in hole or player on goal
        if self.desc[position] == b"H":
            room_state[position] = 4
        elif self.desc[position] == b"G":
            room_state[position] = 5
        return room_state

    def render(self, mode="tiny_rgb_array"):
        """Render the environment.

        Args:
            mode: Rendering mode. Options: "tiny_rgb_array", "list", "state", "rgb_array", "ansi".

        Returns:
            Rendered observation based on the mode: the gymnasium renderer's
            output for "rgb_array" / "ansi", an integer numpy array for
            "state", a list of row strings for "list", and a newline-joined
            text grid for "tiny_rgb_array".
        """
        assert mode in ["tiny_rgb_array", "list", "state", "rgb_array", "ansi"]
        if mode in ["rgb_array", "ansi"]:
            prev_render_mode = self.render_mode
            self.render_mode = mode
            try:
                return GymFrozenLakeEnv.render(self)
            finally:
                # Restore the configured mode even if the renderer raises.
                self.render_mode = prev_render_mode

        state = self._render_state()
        if mode == "state":
            return state

        rows = state.tolist()

        if mode == "list":
            return [
                " ".join(self.GRID_LOOKUP.get(cell, "?").strip("\t").strip() for cell in row)
                for row in rows
            ]

        # mode == "tiny_rgb_array"
        return "\n".join("".join(self.GRID_LOOKUP.get(cell, "?") for cell in row) for row in rows)

    def reset(self, task: Optional[Dict] = None):
        """Re-initialise the environment and return the first observation.

        Args:
            task: Optional overrides for ``size``, ``p``, ``seed``,
                ``is_slippery`` and ``max_steps``; missing keys fall back to
                the values from the previous initialisation.

        Returns:
            Tuple of (text observation, empty info dict).
        """
        task = task or {}
        # NOTE(review): the constructor's ``desc`` is not forwarded, so a custom
        # map is replaced by a random one on reset — confirm this is intended.
        self.__init__(  # type: ignore [misc]
            # Forward the step budget; otherwise it silently reverts to the
            # default and changes the generated map.
            max_steps=task.get("max_steps", self.max_steps),
            size=task.get("size", self.map_kwargs["size"]),
            p=task.get("p", self.map_kwargs["p"]),
            seed=task.get("seed", self.env_kwargs["seed"]),
            is_slippery=task.get("is_slippery", self.env_kwargs["is_slippery"]),
        )
        GymFrozenLakeEnv.reset(self, seed=self.seed)
        return self.render(mode="tiny_rgb_array"), {}

    def finished(self) -> bool:
        """Check if the player stands on a terminal tile: the goal (G) or a hole (H)."""
        player_pos = self._get_player_position()
        return self.desc[player_pos] in (b"G", b"H")  # type: ignore [index,operator]

    def success(self):
        """
        Check if the agent has reached the goal (G).
        """
        player_pos = self._get_player_position()
        return bool(self.desc[player_pos] == b"G")
0 commit comments