diff --git a/.gitignore b/.gitignore
index cdfdf567..41014431 100644
--- a/.gitignore
+++ b/.gitignore
@@ -92,7 +92,7 @@ ipython_config.py
 # pyenv
 # For a library or package, you might want to ignore these files since the code is
 # intended to run in multiple environments; otherwise, check them in:
-# .python-version
+.python-version
 
 # pipenv
 # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
diff --git a/docs/assets/pinpad.gif b/docs/assets/pinpad.gif
new file mode 100644
index 00000000..9879adce
Binary files /dev/null and b/docs/assets/pinpad.gif differ
diff --git a/docs/assets/pinpad_variations.gif b/docs/assets/pinpad_variations.gif
new file mode 100644
index 00000000..7c59165d
Binary files /dev/null and b/docs/assets/pinpad_variations.gif differ
diff --git a/docs/envs/pinpad.md b/docs/envs/pinpad.md
new file mode 100644
index 00000000..62596659
--- /dev/null
+++ b/docs/envs/pinpad.md
@@ -0,0 +1,234 @@
+---
+title: PinPad
+summary: A 2D grid-based navigation task with colored pads
+external_links:
+  arxiv: https://arxiv.org/abs/2206.04114
+  github: https://github.com/danijar/director/tree/main
+  main_page: https://danijar.com/project/director/
+---
+
+![pinpad](../assets/pinpad.gif)
+
+## Description
+
+A 2D grid-based navigation task in which the agent must reach a target colored pad within a structured layout. The environment features multiple colored pads arranged in various configurations, with walls constraining movement. The agent receives a visual observation of the current state and a goal image showing the target configuration.
+
+The task requires the agent to:
+1. Identify the target colored pad from the goal image
+2. Navigate through the open spaces without crossing walls
+3. Reach the target pad (either by stepping on it or by being fully within its boundaries)
+
+**Success criteria**: The episode is successful when the agent reaches the target pad. In the continuous version (`PinPad`), all four of the agent's corners must lie within the target pad, meaning the agent is entirely contained in it. In the discrete version (`PinPadDiscrete`), the agent succeeds by simply stepping onto the target pad, since its position is grid-aligned.
+
+```python
+import stable_worldmodel as swm
+
+# Continuous action space version
+world = swm.World('swm/PinPad-v0', num_envs=4, image_shape=(224, 224))
+
+# Discrete action space version
+world = swm.World('swm/PinPad-Discrete-v0', num_envs=4, image_shape=(224, 224))
+```
+
+## Environment Specs
+
+### Continuous Action Space (`PinPad-v0`)
+
+| Property | Value |
+|----------|-------|
+| Action Space | `Box(-1, 1, shape=(2,))` — 2D continuous movement (dx, dy) |
+| Observation Space | `Dict(image=(224, 224, 3), agent_position=(2,))` |
+| Reward | 10 for reaching the target, 0 otherwise |
+| Episode Length | 100 steps (default) |
+| Render Size | 224×224 (16×16 grid at 14× scale) |
+| Grid Size | 16×16 cells |
+
+#### Action Details
+
+Actions are 2D vectors (dx, dy) with each component in [-1, 1] that move the agent continuously. The agent's center position is constrained so that it cannot move through walls (#).
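+
+As a minimal sketch of this convention (instantiating the environment class directly rather than going through `swm.World`, and using only the classes added in this change), an action that heads toward the goal can be built from the positions exposed by `reset()`:
+
+```python
+import numpy as np
+
+from stable_worldmodel.envs.pinpad import PinPad
+
+env = PinPad()
+obs, info = env.reset(seed=0)
+
+# Step toward the goal while keeping each component inside the [-1, 1] action bounds
+delta = info['goal_position'] - obs['agent_position']
+action = np.clip(delta, -1.0, 1.0)
+obs, reward, terminated, truncated, info = env.step(action)
+```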
+
+#### Observation Details
+
+| Key | Shape | Description |
+|-----|-------|-------------|
+| `image` | `(224, 224, 3)` | RGB rendering of the current state |
+| `agent_position` | `(2,)` | Agent position (x, y) as floats in [1.5, 14.5] |
+
+#### Info Dictionary
+
+The `info` dict returned by `step()` and `reset()` contains:
+
+| Key | Description |
+|-----|-------------|
+| `goal` | Goal image (224, 224, 3) showing the agent placed on the target pad |
+| `goal_position` | Position (x, y) of the target-pad cell farthest from the grid center |
+
+### Discrete Action Space (`PinPad-Discrete-v0`)
+
+| Property | Value |
+|----------|-------|
+| Action Space | `Discrete(5)` — {no-op, up, down, right, left} |
+| Observation Space | `Dict(image=(224, 224, 3), agent_position=(2,))` |
+| Reward | 10 for reaching the target, 0 otherwise |
+| Episode Length | 100 steps (default) |
+| Render Size | 224×224 (16×16 grid at 14× scale) |
+| Grid Size | 16×16 cells |
+
+#### Action Details
+
+Actions are discrete movements:
+- `0`: No-op (stay in place)
+- `1`: Move up (dy = +1)
+- `2`: Move down (dy = -1)
+- `3`: Move right (dx = +1)
+- `4`: Move left (dx = -1)
+
+The agent cannot move through walls (#). Movement is grid-aligned with integer positions.
+
+#### Observation Details
+
+| Key | Shape | Description |
+|-----|-------|-------------|
+| `image` | `(224, 224, 3)` | RGB rendering of the current state |
+| `agent_position` | `(2,)` | Agent position (x, y) as integers in [0, 15] |
+
+#### Info Dictionary
+
+The `info` dict returned by `step()` and `reset()` contains:
+
+| Key | Description |
+|-----|-------------|
+| `goal` | Goal image (224, 224, 3) showing the agent placed on the target pad |
+| `goal_position` | Grid position (x, y) of the target-pad cell farthest from the grid center |
+
+## Variation Space
+
+![pinpad_variations](../assets/pinpad_variations.gif)
+
+The environment supports customization through the variation space:
+
+| Factor | Type | Description |
+|--------|------|-------------|
+| `agent.spawn` | Box([1.5, 1.5], [14.5, 14.5]) (continuous) or Discrete(196) (discrete) | Agent starting position |
+| `agent.target_pad` | Box(0.0, 1.0) | Target pad selection (linearly mapped to the available pads) |
+| `grid.task` | Discrete(6) | Task layout: 'three', 'four', 'five', 'six', 'seven', 'eight' |
+
+### Task Layouts
+
+The environment includes 6 predefined layouts with different numbers of colored pads:
+
+| Task Name | Number of Pads | Description |
+|-----------|----------------|-------------|
+| `three` | 3 pads | Pads 1, 2, 3 arranged in a T-shape |
+| `four` | 4 pads | Pads 1, 2, 3, 4 in the corners |
+| `five` | 5 pads | Pads 1-5 in a more complex layout |
+| `six` | 6 pads | Pads 1-6 arranged on the left and right sides |
+| `seven` | 7 pads | Pads 1-7 with a central pad |
+| `eight` | 8 pads | Pads 1-8, placed on the sides and corners |
+
+### Default Variations
+
+By default, these factors are randomized at each reset:
+- `agent.spawn` — Starting position
+- `agent.target_pad` — Which pad is the goal
+
+The task layout (`grid.task`) defaults to 'three' but can be randomized:
+
+```python
+# Randomize task layout
+world.reset(options={'variation': ['grid.task']})
+
+# Randomize everything
+world.reset(options={'variation': ['all']})
+
+# Fix a specific task
+world.reset(options={
+    'variation': ['grid.task'],
+    'variation_values': {'grid.task': 5}  # Use 'eight' layout
+})
+```
+
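+Because the number of pads depends on the chosen task layout, `agent.target_pad` is a float that is binned linearly into a pad index at reset time (the environment computes `int(value * num_pads)` and asserts the result is in range). A hypothetical helper sketching that mapping — the clamp only guards the edge case of a value of exactly 1.0:
+
+```python
+def target_pad_index(value: float, num_pads: int) -> int:
+    """Map a value in [0, 1) to a pad index, e.g. 0.5 with 6 pads -> index 3."""
+    return min(int(value * num_pads), num_pads - 1)
+```
+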
+### Pad Colors
+
+Each numbered pad has a distinct color:
+
+| Pad | Color (RGB) |
+|-----|-------------|
+| 1 | Red (255, 0, 0) |
+| 2 | Green (0, 255, 0) |
+| 3 | Blue (0, 0, 255) |
+| 4 | Yellow (255, 255, 0) |
+| 5 | Magenta (255, 0, 255) |
+| 6 | Cyan (0, 255, 255) |
+| 7 | Purple (128, 0, 128) |
+| 8 | Teal (0, 128, 128) |
+
+Pads are dimmed (10% color + 90% white) when not occupied, and shown at full color when the agent occupies them (stands on them in the discrete version, is entirely contained within them in the continuous version).
+
+## Expert Policy
+
+The environment includes built-in expert policies for both continuous and discrete action spaces. Both policies move directly toward the goal position provided in the `info` dict.
+
+### Continuous Action Space
+
+```python
+from stable_worldmodel.envs.pinpad.expert_policy import ExpertPolicy
+
+policy = ExpertPolicy(max_norm=1.0, add_noise=True)
+world.set_policy(policy)
+```
+
+#### ExpertPolicy Parameters
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `max_norm` | float | 1.0 | Maximum action magnitude. Actions are clipped to this L2 norm. |
+| `add_noise` | bool | True | Whether to add Gaussian noise (σ=1.0) to actions before clipping. Useful for generating diverse trajectories. |
+
+The expert policy computes the vector from the agent's current position to the goal position, optionally adds noise, and clips the action to the maximum norm.
+
+### Discrete Action Space
+
+```python
+from stable_worldmodel.envs.pinpad import ExpertPolicyDiscrete
+
+policy = ExpertPolicyDiscrete()
+world.set_policy(policy)
+```
+
+The discrete expert policy uses the Manhattan distance to determine the optimal action. When both horizontal and vertical movement are needed, it alternates between them.
+
+### Usage with Vectorized Environments
+
+Both expert policies work with vectorized environments:
+
+```python
+import stable_worldmodel as swm
+from stable_worldmodel.envs.pinpad.expert_policy import ExpertPolicy
+
+# Continuous version
+world = swm.World('swm/PinPad-v0', num_envs=10, image_shape=(224, 224))
+policy = ExpertPolicy(max_norm=0.8, add_noise=True)
+world.set_policy(policy)
+
+# Collect data with random task variations
+world.record_dataset(
+    dataset_name='pinpad_expert',
+    episodes=1000,
+    seed=42,
+    options={'variation': ['agent.spawn', 'agent.target_pad', 'grid.task']}
+)
+```
+
+```python
+# Discrete version
+world = swm.World('swm/PinPad-Discrete-v0', num_envs=10, image_shape=(224, 224))
+policy = ExpertPolicyDiscrete()
+world.set_policy(policy)
+
+world.record_dataset(
+    dataset_name='pinpad_discrete_expert',
+    episodes=1000,
+    seed=42
+)
+```
diff --git a/scripts/data/collect_pinpad.py b/scripts/data/collect_pinpad.py
new file mode 100644
index 00000000..4cd91ed7
--- /dev/null
+++ b/scripts/data/collect_pinpad.py
@@ -0,0 +1,48 @@
+import hydra
+from loguru import logger as logging
+import numpy as np
+
+import stable_worldmodel as swm
+
+
+@hydra.main(version_base=None, config_path='./', config_name='config')
+def run(cfg):
+    """Run data collection script"""
+
+    world = swm.World('swm/PinPad-v0', **cfg.world)
+    world.set_policy(swm.envs.pinpad.expert_policy.ExpertPolicy(max_norm=0.25))
+    logging.info("Set world's policy to expert policy")
+
+    logging.info(f'Collecting data for {cfg.num_traj} trajectories')
+    dataset_name = 'pinpad'
+    world.record_dataset(
+        dataset_name,
+        episodes=cfg.num_traj,
+        seed=np.random.default_rng(cfg.seed).integers(0, int(2**20)).item(),
+        cache_dir=cfg.cache_dir,
+        options=cfg.get('options'),
+    )
+    logging.success(
+        f' 🎉🎉🎉 Completed data collection for {dataset_name} 🎉🎉🎉'
+    )
+
+    dataset = swm.data.HDF5Dataset(
+        name=dataset_name,
+        keys_to_load=['pixels', 'agent_position', 'action'],
+    )
+    
logging.info(f'Loaded dataset from {dataset.h5_path}') + swm.utils.record_video_from_dataset( + video_path='./videos/pinpad', + dataset=dataset, + episode_idx=[0, 1, 2, 3], + max_steps=cfg.world.max_episode_steps, + fps=30, + viewname='pixels', + ) + logging.success( + f' 🎉🎉🎉 Completed video recording from dataset for {dataset_name} 🎉🎉🎉' + ) + + +if __name__ == '__main__': + run() diff --git a/scripts/data/collect_pinpad_discrete.py b/scripts/data/collect_pinpad_discrete.py new file mode 100644 index 00000000..4109cb33 --- /dev/null +++ b/scripts/data/collect_pinpad_discrete.py @@ -0,0 +1,48 @@ +import hydra +from loguru import logger as logging +import numpy as np + +import stable_worldmodel as swm + + +@hydra.main(version_base=None, config_path='./', config_name='config') +def run(cfg): + """Run data collection script""" + + world = swm.World('swm/PinPad-Discrete-v0', **cfg.world) + world.set_policy(swm.envs.pinpad.ExpertPolicyDiscrete()) + logging.info("Set world's policy to expert policy") + + logging.info(f'Collecting data for {cfg.num_traj} trajectories') + dataset_name = 'pinpad_discrete' + world.record_dataset( + dataset_name, + episodes=cfg.num_traj, + seed=np.random.default_rng(cfg.seed).integers(0, int(2**20)).item(), + cache_dir=cfg.cache_dir, + options=cfg.get('options'), + ) + logging.success( + f' 🎉🎉🎉 Completed data collection for {dataset_name} 🎉🎉🎉' + ) + + dataset = swm.data.HDF5Dataset( + name=dataset_name, + keys_to_load=['pixels', 'agent_position', 'action'], + ) + logging.info(f'Loaded dataset from {dataset.h5_path}') + swm.utils.record_video_from_dataset( + video_path='./videos/pinpad_discrete', + dataset=dataset, + episode_idx=[0, 1, 2, 3], + max_steps=cfg.world.max_episode_steps, + fps=30, + viewname='pixels', + ) + logging.success( + f' 🎉🎉🎉 Completed video recording from dataset for {dataset_name} 🎉🎉🎉' + ) + + +if __name__ == '__main__': + run() diff --git a/stable_worldmodel/envs/__init__.py b/stable_worldmodel/envs/__init__.py index bf53b44f..e826d430 100644 --- a/stable_worldmodel/envs/__init__.py +++ b/stable_worldmodel/envs/__init__.py @@ -108,6 +108,11 @@ def register(id, entry_point): entry_point='stable_worldmodel.envs.dmcontrol.quadruped:QuadrupedDMControlWrapper', ) +register( + id='swm/PinPad-v0', + entry_point='stable_worldmodel.envs.pinpad:PinPad', +) + ############ # DISCRETE # @@ -122,3 +127,8 @@ def register(id, entry_point): id='swm/PushT-Discrete-v1', entry_point='stable_worldmodel.envs.pusht:PushTDiscrete', ) + +register( + id='swm/PinPad-Discrete-v0', + entry_point='stable_worldmodel.envs.pinpad:PinPadDiscrete', +) diff --git a/stable_worldmodel/envs/pinpad/__init__.py b/stable_worldmodel/envs/pinpad/__init__.py new file mode 100644 index 00000000..ebcb09f0 --- /dev/null +++ b/stable_worldmodel/envs/pinpad/__init__.py @@ -0,0 +1,6 @@ +from .env import PinPad +from .env_discrete import PinPadDiscrete +from .expert_policy import ExpertPolicyDiscrete + + +__all__ = ['PinPad', 'PinPadDiscrete', 'ExpertPolicyDiscrete'] diff --git a/stable_worldmodel/envs/pinpad/constants.py b/stable_worldmodel/envs/pinpad/constants.py new file mode 100644 index 00000000..477dff02 --- /dev/null +++ b/stable_worldmodel/envs/pinpad/constants.py @@ -0,0 +1,141 @@ +"""Constants for the PinPad environments.""" + +COLORS = { + '1': (255, 0, 0), + '2': ( 0, 255, 0), + '3': ( 0, 0, 255), + '4': (255, 255, 0), + '5': (255, 0, 255), + '6': ( 0, 255, 255), + '7': (128, 0, 128), + '8': ( 0, 128, 128), +} + +X_BOUND = 16 +Y_BOUND = 16 +RENDER_SCALE = 14 + +TASK_NAMES = ['three', 
'four', 'five', 'six', 'seven', 'eight'] + +LAYOUT_THREE = """ +################ +#1111 3333# +#1111 3333# +#1111 3333# +#1111 3333# +# # +# # +# # +# # +# # +# # +# 2222 # +# 2222 # +# 2222 # +# 2222 # +################ +""".strip('\n') + +LAYOUT_FOUR = """ +################ +#1111 4444# +#1111 4444# +#1111 4444# +#1111 4444# +# # +# # +# # +# # +# # +# # +#3333 2222# +#3333 2222# +#3333 2222# +#3333 2222# +################ +""".strip('\n') + +LAYOUT_FIVE = """ +################ +# 4444# +# 4444# +#111 4444# +#111 # +#111 # +#111 555# +# 555# +# 555# +#333 555# +#333 # +#333 # +#333 2222# +# 2222# +# 2222# +################ +""".strip('\n') + +LAYOUT_SIX = """ +################ +#111 555# +#111 555# +#111 555# +# # +# # +#33 66# +#33 66# +#33 66# +#33 66# +# # +# # +#444 222# +#444 222# +#444 222# +################ +""".strip('\n') + +LAYOUT_SEVEN = """ +################ +#111 444# +#111 444# +#11 44# +# # +# # +#33 55# +#33 55# +#33 55# +#33 55# +# # +# # +#66 22# +#666 7777 222# +#666 7777 222# +################ +""".strip('\n') + +LAYOUT_EIGHT = """ +################ +#111 8888 444# +#111 8888 444# +#11 44# +# # +# # +#33 55# +#33 55# +#33 55# +#33 55# +# # +# # +#66 22# +#666 7777 222# +#666 7777 222# +################ +""".strip('\n') + +LAYOUTS = { + 'three': LAYOUT_THREE, + 'four': LAYOUT_FOUR, + 'five': LAYOUT_FIVE, + 'six': LAYOUT_SIX, + 'seven': LAYOUT_SEVEN, + 'eight': LAYOUT_EIGHT, +} diff --git a/stable_worldmodel/envs/pinpad/env.py b/stable_worldmodel/envs/pinpad/env.py new file mode 100644 index 00000000..4096fbf0 --- /dev/null +++ b/stable_worldmodel/envs/pinpad/env.py @@ -0,0 +1,280 @@ +import gymnasium as gym +import numpy as np +from PIL import Image, ImageDraw + +from stable_worldmodel import spaces as swm_spaces +from stable_worldmodel.envs.pinpad.constants import ( + COLORS, + X_BOUND, + Y_BOUND, + RENDER_SCALE, + TASK_NAMES, + LAYOUTS, +) + + +DEFAULT_VARIATIONS = ( + 'agent.spawn', + 'agent.target_pad', +) + + +# TODO: Re-enable targets to be sequences of pads instead of single pads +# TODO: Add walls to the environment, with the number of walls controlled by +# the variation space +class PinPad(gym.Env): + def __init__( + self, + seed=None, + init_value=None, + render_mode='rgb_array', # For backward compatibility; not used + ): + # Build variation space + self.variation_space = self._build_variation_space() + if init_value is not None: + self.variation_space.set_init_value(init_value) + + # Other spaces + self.observation_space = gym.spaces.Dict( + { + 'image': gym.spaces.Box( + low=0, + high=255, + shape=(Y_BOUND * RENDER_SCALE, X_BOUND * RENDER_SCALE, 3), + dtype=np.uint8, + ), + 'agent_position': gym.spaces.Box( + low=np.array([1.5, 1.5], dtype=np.float64), + high=np.array( + [X_BOUND - 1.5, Y_BOUND - 1.5], dtype=np.float64 + ), + shape=(2,), + dtype=np.float64, + ), + } + ) + self.action_space = gym.spaces.Box( + low=-1.0, + high=1.0, + shape=(2,), + dtype=np.float64, + ) + + # To be initialized in reset() + self.task = None + self.layout = None + self.pads = None + self.player = None + self.target_pad = None + + def _build_variation_space(self): + return swm_spaces.Dict( + { + 'agent': swm_spaces.Dict( + { + 'spawn': swm_spaces.Box( + low=np.array([1.5, 1.5], dtype=np.float64), + high=np.array( + [X_BOUND - 1.5, Y_BOUND - 1.5], + dtype=np.float64, + ), + init_value=np.array( + [X_BOUND / 2, Y_BOUND / 2], dtype=np.float64 + ), + shape=(2,), + dtype=np.float64, + ), + # The number of pads is dynamic based on the task, + # so we generate the index as a float in 
[0, 1) and then + # scale it to the number of pads before truncating it to an int + 'target_pad': swm_spaces.Box( + low=0.0, + high=1.0, + init_value=0.0, + shape=(), + dtype=np.float64, + ), + } + ), + 'grid': swm_spaces.Dict( + { + 'task': swm_spaces.Discrete( + n=len(TASK_NAMES), + start=0, + init_value=0, # 0 = 'three', 5 = 'eight' + ), + } + ), + }, + sampling_order=['grid', 'agent'], + ) + + def _setup_layout(self, task): + layout = LAYOUTS[task] + self.layout = np.array( + [list(line) for line in layout.split('\n')] + ).T # Transposes so that actions are (dx, dy) + assert self.layout.shape == (X_BOUND, Y_BOUND), ( + f'Layout shape should be ({X_BOUND}, {Y_BOUND}), got {self.layout.shape}' + ) + + def _setup_pads(self): + self.pads = sorted( + list(set(self.layout.flatten().tolist()) - set('* #\n')) + ) + + def reset(self, seed=None, options=None): + super().reset(seed=seed, options=options) + + # Reset variation space + options = options or {} + swm_spaces.reset_variation_space( + self.variation_space, + seed, + options, + DEFAULT_VARIATIONS, + ) + + # Update task if it changed or if this is the first reset + task_idx = int(self.variation_space['grid']['task'].value) + new_task = TASK_NAMES[task_idx] + if new_task != self.task or self.task is None: + self.task = new_task + self._setup_layout(self.task) + self._setup_pads() + + # Set player position directly from variation space + spawn_position = self.variation_space['agent']['spawn'].value + self.player = tuple(spawn_position) + + # Set target pad from variation space using linear binning + target_pad_value = float( + self.variation_space['agent']['target_pad'].value + ) + target_pad_idx = int(target_pad_value * len(self.pads)) + assert target_pad_idx >= 0 and target_pad_idx < len(self.pads), ( + f'Target pad index {target_pad_idx} is out of range for {len(self.pads)} pads' + ) + self.target_pad = self.pads[target_pad_idx] + self.goal_position = self._get_goal_position(self.target_pad) + self.goal = self.render(player_position=self.goal_position) + + # Gets return values + obs = self._get_obs() + info = self._get_info() + return obs, info + + def _get_obs(self): + return { + 'image': self.render(), + 'agent_position': np.array(self.player, dtype=np.float64), + } + + def step(self, action): + # Moves player + x = np.clip(self.player[0] + action[0], 1.5, X_BOUND - 1.5) + y = np.clip(self.player[1] + action[1], 1.5, Y_BOUND - 1.5) + tile = self.layout[int(x)][int(y)] + if ( + tile != '#' + ): # TODO: Add linear interpolation in case of wall collision + self.player = (float(x), float(y)) + + # Makes observation + agent_in_target_pad = self._agent_in_target_pad( + self.player, self.target_pad + ) + obs = self._get_obs() + reward = 10.0 if agent_in_target_pad else 0.0 + terminated = agent_in_target_pad # TODO: Maybe always set to false? 
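+        # Note: success requires the agent to be entirely inside the target pad (all four corners).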
+ truncated = False + info = self._get_info() + return obs, reward, terminated, truncated, info + + def _get_goal_position(self, target_pad): + target_cells = np.array( + list(zip(*np.where(self.layout == target_pad))), dtype=np.float64 + ) + target_cell_centers = target_cells + 0.5 + center_cell = np.array([X_BOUND / 2, Y_BOUND / 2], dtype=np.float64) + farthest_idx = np.argmax( + np.linalg.norm(target_cell_centers - center_cell, axis=1) + ) + farthest_from_center = target_cell_centers[farthest_idx] + return farthest_from_center + + def _agent_in_target_pad(self, player, target_pad): + # Gets all cells that overlap with the agent + corner_deltas = np.array( + [ + (-0.5, -0.5), + (-0.5, 0.5), + (0.5, -0.5), + (0.5, 0.5), + ], + dtype=np.float64, + ) + corner_positions = player + corner_deltas + distinct_corner_positions = [ + tuple(pos) + for pos in np.unique(corner_positions.astype(int), axis=0) + ] + + # Gets all cells from the target pad + target_cells = np.array( + list(zip(*np.where(self.layout == target_pad))), dtype=np.float64 + ) + target_cells = [tuple(pos) for pos in target_cells.astype(int)] + + # Checks that the agent is entirely within the target pad + for pos in distinct_corner_positions: + if pos not in target_cells: + return False + return True + + def _get_info(self): + info = { + 'goal_position': np.array(self.goal_position), + 'goal': self.goal, + } + return info + + def render(self, player_position=None): + # Sets up grid + grid = np.zeros((X_BOUND, Y_BOUND, 3), np.uint8) + 255 + white = np.array([255, 255, 255]) + if player_position is None: + player_position = self.player + + # Colors all cells except agent + for (x, y), char in np.ndenumerate(self.layout): + if char == '#': + grid[x, y] = (192, 192, 192) # Gray + elif char in self.pads: + color = np.array(COLORS[char]) + color = ( + color + if self._agent_in_target_pad(player_position, char) + else (10 * color + 90 * white) / 100 + ) + grid[x, y] = color + + # Scales up and transposes grid + image = np.repeat(np.repeat(grid, RENDER_SCALE, 0), RENDER_SCALE, 1) + image = image.transpose((1, 0, 2)) + + # Places agent with anti-aliasing + image_pil = Image.fromarray(image, mode='RGB') + draw = ImageDraw.Draw(image_pil) + x, y = player_position + draw.rectangle( + [ + (x - 0.5) * RENDER_SCALE, + (y - 0.5) * RENDER_SCALE, + (x + 0.5) * RENDER_SCALE, + (y + 0.5) * RENDER_SCALE, + ], + fill=(0, 0, 0), # Agent is black + ) + image = np.asarray(image_pil) + return image diff --git a/stable_worldmodel/envs/pinpad/env_discrete.py b/stable_worldmodel/envs/pinpad/env_discrete.py new file mode 100644 index 00000000..b8ad34d1 --- /dev/null +++ b/stable_worldmodel/envs/pinpad/env_discrete.py @@ -0,0 +1,232 @@ +import gymnasium as gym +import numpy as np + +from stable_worldmodel import spaces as swm_spaces +from stable_worldmodel.envs.pinpad.constants import ( + COLORS, + X_BOUND, + Y_BOUND, + RENDER_SCALE, + TASK_NAMES, + LAYOUTS, +) + + +DEFAULT_VARIATIONS = ( + 'agent.spawn', + 'agent.target_pad', +) + + +# TODO: Re-enable targets to be sequences of pads instead of single pads +# TODO: Add walls to the environment, with the number of walls controlled by +# the variation space +class PinPadDiscrete(gym.Env): + def __init__( + self, + seed=None, + init_value=None, + render_mode='rgb_array', # For backward compatibility; not used + ): + # Build variation space + self.variation_space = self._build_variation_space() + if init_value is not None: + self.variation_space.set_init_value(init_value) + + # Other spaces + 
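+        # Same observation structure as the continuous PinPad, but agent_position is an integer grid cell.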
self.observation_space = gym.spaces.Dict( + { + 'image': gym.spaces.Box( + low=0, + high=255, + shape=(Y_BOUND * RENDER_SCALE, X_BOUND * RENDER_SCALE, 3), + dtype=np.uint8, + ), + 'agent_position': gym.spaces.Box( + low=np.array([0, 0], dtype=np.int64), + high=np.array([X_BOUND - 1, Y_BOUND - 1], dtype=np.int64), + shape=(2,), + dtype=np.int64, + ), + } + ) + self.action_space = gym.spaces.Discrete(5) # [0, 5) + + # To be initialized in reset() + self.task = None + self.layout = None + self.pads = None + self.spawns = None + self.player = None + self.target_pad = None + + def _build_variation_space(self): + # Spawn locations don't include walls + max_spawns = X_BOUND * Y_BOUND - 2 * (X_BOUND + Y_BOUND - 2) + + return swm_spaces.Dict( + { + 'agent': swm_spaces.Dict( + { + 'spawn': swm_spaces.Discrete( + n=max_spawns, + start=0, + init_value=0, + ), + # The number of pads is dynamic based on the task, + # so we generate the index as a float in [0, 1) and then + # scale it to the number of pads before truncating it to an int + 'target_pad': swm_spaces.Box( + low=0.0, + high=1.0, + init_value=0.0, + shape=(), + dtype=np.float64, + ), + } + ), + 'grid': swm_spaces.Dict( + { + 'task': swm_spaces.Discrete( + n=len(TASK_NAMES), + start=0, + init_value=0, # 0 = 'three', 5 = 'eight' + ), + } + ), + }, + sampling_order=['grid', 'agent'], + ) + + def _setup_layout(self, task): + layout = LAYOUTS[task] + self.layout = np.array( + [list(line) for line in layout.split('\n')] + ).T # Transposes so that actions are (dx, dy) + assert self.layout.shape == (X_BOUND, Y_BOUND), ( + f'Layout shape should be ({X_BOUND}, {Y_BOUND}), got {self.layout.shape}' + ) + + def _setup_pads_and_spawns(self): + self.pads = sorted( + list(set(self.layout.flatten().tolist()) - set('* #\n')) + ) + self.spawns = [] + for (x, y), char in np.ndenumerate(self.layout): + if char != '#': + self.spawns.append((x, y)) + + def reset(self, seed=None, options=None): + super().reset(seed=seed, options=options) + + # Reset variation space + options = options or {} + swm_spaces.reset_variation_space( + self.variation_space, + seed, + options, + DEFAULT_VARIATIONS, + ) + + # Update task if it changed or if this is the first reset + task_idx = int(self.variation_space['grid']['task'].value) + new_task = TASK_NAMES[task_idx] + if new_task != self.task or self.task is None: + self.task = new_task + self._setup_layout(self.task) + self._setup_pads_and_spawns() + + # Set player position from variation space (index into spawns) + spawn_idx = int(self.variation_space['agent']['spawn'].value) + assert spawn_idx >= 0 and spawn_idx < len(self.spawns), ( + f'Spawn index {spawn_idx} is out of range for {len(self.spawns)} spawns' + ) + self.player = self.spawns[spawn_idx] + + # Set target pad from variation space using linear binning + target_pad_value = float( + self.variation_space['agent']['target_pad'].value + ) + target_pad_idx = int(target_pad_value * len(self.pads)) + assert target_pad_idx >= 0 and target_pad_idx < len(self.pads), ( + f'Target pad index {target_pad_idx} is out of range for {len(self.pads)} pads' + ) + self.target_pad = self.pads[target_pad_idx] + self.goal_position = self._get_goal_position(self.target_pad) + self.goal = self.render(player_position=self.goal_position) + + # Gets return values + obs = self._get_obs() + info = self._get_info() + return obs, info + + def _get_obs(self): + return { + 'image': self.render(), + 'agent_position': np.array(self.player, dtype=np.int64), + } + + def step(self, action): + # Moves player + 
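+        # Action mapping: 0 = no-op, 1 = up (+y), 2 = down (-y), 3 = right (+x), 4 = left (-x)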
move = [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)][action] + x = np.clip(self.player[0] + move[0], 0, X_BOUND - 1) + y = np.clip(self.player[1] + move[1], 0, Y_BOUND - 1) + tile = self.layout[x][y] + if tile != '#': + self.player = (x, y) + + # Gets reward + reward = 0.0 + if tile == self.target_pad: + reward += 10.0 + + # Makes observation + obs = self._get_obs() + terminated = tile == self.target_pad + truncated = False + info = self._get_info() + return obs, reward, terminated, truncated, info + + def _get_goal_position(self, target_pad): + target_cells = list(zip(*np.where(self.layout == target_pad))) + center_cell = (X_BOUND // 2, Y_BOUND // 2) + farthest_idx = np.argmax( + np.linalg.norm( + np.array(target_cells) - np.array(center_cell), axis=1 + ) + ) + farthest_from_center = target_cells[farthest_idx] + return farthest_from_center + + def _get_info(self): + info = { + 'goal_position': np.array(self.goal_position), + 'goal': self.goal, + } + return info + + def render(self, player_position=None): + # Sets up grid + grid = np.zeros((X_BOUND, Y_BOUND, 3), np.uint8) + 255 + white = np.array([255, 255, 255]) + if player_position is None: + player_position = self.player + current = self.layout[player_position[0]][player_position[1]] + + # Colors all cells + for (x, y), char in np.ndenumerate(self.layout): + if char == '#': + grid[x, y] = (192, 192, 192) # Gray + elif char in self.pads: + color = np.array(COLORS[char]) + color = ( + color + if char == current + else (10 * color + 90 * white) / 100 + ) + grid[x, y] = color + grid[player_position] = (0, 0, 0) + + # Scales up + image = np.repeat(np.repeat(grid, RENDER_SCALE, 0), RENDER_SCALE, 1) + return image.transpose((1, 0, 2)) diff --git a/stable_worldmodel/envs/pinpad/expert_policy.py b/stable_worldmodel/envs/pinpad/expert_policy.py new file mode 100644 index 00000000..da2159d1 --- /dev/null +++ b/stable_worldmodel/envs/pinpad/expert_policy.py @@ -0,0 +1,136 @@ +import numpy as np +from stable_worldmodel.policy import BasePolicy + + +def compute_action_discrete(agent_position, goal_position): + dx, dy = (goal_position - agent_position).tolist() + if abs(dx) + abs(dy): + # Gets directions we need to move in (in the transposed space) + possible_actions = [] + if abs(dx): + if dx > 0: + possible_actions.append(3) # right + else: + possible_actions.append(4) # left + if abs(dy): + if dy > 0: + possible_actions.append(1) # up + else: + possible_actions.append(2) # down + + # Alternates between horizontal and vertical moves + if len(possible_actions) == 2: + action = possible_actions[(abs(dx) + abs(dy)) % 2] + else: + action = possible_actions[0] + else: + action = 0 + return action + + +def compute_action_continuous( + agent_position, goal_position, max_norm, add_noise, rng +): + delta = goal_position - agent_position + if add_noise: + delta = delta + rng.normal(0, 1, delta.shape) + + if np.linalg.norm(delta) > max_norm: + action = max_norm * delta / np.linalg.norm(delta) # Clips norm + else: + action = delta + return action + + +def get_action(info_dict, env, env_type, **kwargs): + # Check if environment is vectorized + base_env = env.unwrapped + if hasattr(base_env, 'envs'): + envs = [e.unwrapped for e in base_env.envs] + is_vectorized = True + else: + envs = [base_env] + is_vectorized = False + + # Computes actions for each environment + actions = [] + dtype = np.int64 if env_type == 'discrete' else np.float64 + for i, env in enumerate(envs): + if is_vectorized: + agent_position = np.asarray( + info_dict['agent_position'][i], dtype=dtype + 
).squeeze() + goal_position = np.asarray( + info_dict['goal_position'][i], dtype=dtype + ).squeeze() + else: + agent_position = np.asarray( + info_dict['agent_position'], dtype=dtype + ).squeeze() + goal_position = np.asarray( + info_dict['goal_position'], dtype=dtype + ).squeeze() + + if env_type == 'discrete': + actions.append( + compute_action_discrete(agent_position, goal_position) + ) + elif env_type == 'continuous': + actions.append( + compute_action_continuous( + agent_position, + goal_position, + kwargs['max_norm'], + kwargs['add_noise'], + kwargs['rng'], + ) + ) + else: + raise ValueError(f'Invalid environment type: {env_type}') + + actions = np.array(actions) + return actions if is_vectorized else actions[0] + + +class ExpertPolicyDiscrete(BasePolicy): + """Expert policy for the PinPadDiscrete environment.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.type = 'expert' + + def get_action(self, info_dict, **kwargs): + assert hasattr(self, 'env'), 'Environment not set for the policy' + assert 'agent_position' in info_dict, ( + 'Agent position must be provided in info_dict' + ) + assert 'goal_position' in info_dict, ( + 'Goal position must be provided in info_dict' + ) + + return get_action(info_dict, self.env, 'discrete', **kwargs) + + +class ExpertPolicy(BasePolicy): + """Expert policy for the PinPad environment.""" + + def __init__(self, max_norm=1.0, add_noise=True, seed=None, **kwargs): + super().__init__(**kwargs) + self.type = 'expert' + self.max_norm = max_norm + self.add_noise = add_noise + self.rng = np.random.default_rng(seed) + + def get_action(self, info_dict, **kwargs): + assert hasattr(self, 'env'), 'Environment not set for the policy' + assert 'agent_position' in info_dict, ( + 'Agent position must be provided in info_dict' + ) + assert 'goal_position' in info_dict, ( + 'Goal position must be provided in info_dict' + ) + + kwargs['max_norm'] = self.max_norm + kwargs['add_noise'] = self.add_noise + kwargs['rng'] = self.rng + return get_action(info_dict, self.env, 'continuous', **kwargs) diff --git a/stable_worldmodel/utils.py b/stable_worldmodel/utils.py index 75b3b537..cd1ca6cc 100644 --- a/stable_worldmodel/utils.py +++ b/stable_worldmodel/utils.py @@ -51,9 +51,7 @@ def pretraining( logging.info('🏁🏁🏁 Pretraining script finished 🏁🏁🏁') -def flatten_dict( - d: dict, parent_key: str = '', sep: str = '.' -) -> dict: +def flatten_dict(d: dict, parent_key: str = '', sep: str = '.') -> dict: """Flatten a nested dictionary into a single-level dictionary. Args: @@ -97,6 +95,7 @@ def record_video_from_dataset( max_steps: int = 500, fps: int = 30, viewname: str | list[str] = 'pixels', + suffix: str = '.mp4', ) -> None: """Replay stored dataset episodes and export them as MP4 videos. @@ -107,6 +106,7 @@ def record_video_from_dataset( max_steps: Maximum frames per video. fps: Frames per second for the output video. viewname: Key(s) in the dataset to use as video frames. + suffix: Suffix for the output video file. 
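+            Use '.gif' to write a looping GIF; any other suffix is passed to imageio unchanged.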
""" import imageio @@ -120,10 +120,18 @@ def record_video_from_dataset( ) for ep_idx in episode_idx: - file_path = Path(video_path, f'episode_{ep_idx}.mp4') + file_path = Path(video_path, f'episode_{ep_idx}{suffix}') steps = dataset.load_episode(ep_idx) frames = np.concatenate([steps[v].numpy() for v in viewname], axis=2) frames = frames[:max_steps] - imageio.mimsave(file_path, frames.transpose(0, 2, 3, 1), fps=fps) + + kwargs = {'fps': fps} + if suffix.lower() == '.gif': + kwargs['loop'] = 0 + imageio.mimsave( + file_path.with_suffix(suffix), + frames.transpose(0, 2, 3, 1), + **kwargs, + ) print(f'Video saved to {video_path}')