Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions embodichain/agents/rl/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Dict, Any, Optional, Callable
from typing import Dict, Any, Callable
import torch


Expand All @@ -42,7 +42,7 @@ def collect_rollout(
policy,
obs: torch.Tensor,
num_steps: int,
on_step_callback: Optional[Callable] = None,
on_step_callback: Callable | None = None,
) -> Dict[str, Any]:
"""Collect trajectories and return logging info (e.g., reward components)."""
raise NotImplementedError
Expand Down
6 changes: 3 additions & 3 deletions embodichain/agents/rl/algo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# ----------------------------------------------------------------------------

import torch
from typing import Dict, Any, Tuple, Callable, Optional
from typing import Dict, Any, Tuple, Callable

from embodichain.agents.rl.utils import AlgorithmCfg
from embodichain.agents.rl.buffer import RolloutBuffer
Expand All @@ -41,7 +41,7 @@ def __init__(self, cfg: PPOCfg, policy):
self.policy = policy
self.device = torch.device(cfg.device)
self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate)
self.buffer: Optional[RolloutBuffer] = None
self.buffer: RolloutBuffer | None = None
# no per-rollout aggregation for dense logging

def _compute_gae(
Expand Down Expand Up @@ -76,7 +76,7 @@ def collect_rollout(
policy,
obs: torch.Tensor,
num_steps: int,
on_step_callback: Optional[Callable] = None,
on_step_callback: Callable | None = None,
) -> Dict[str, Any]:
"""Collect a rollout. Algorithm controls the data collection process."""
if self.buffer is None:
Expand Down
2 changes: 1 addition & 1 deletion embodichain/agents/rl/models/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from functools import reduce
from typing import Iterable, List, Optional, Sequence, Tuple, Union
from typing import Iterable, List, Sequence, Tuple, Union

import torch
import torch.nn as nn
Expand Down
2 changes: 1 addition & 1 deletion embodichain/agents/rl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Dict, Any, Tuple, Callable, Optional
from typing import Dict, Any, Tuple, Callable
import time
import numpy as np
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import matplotlib.pyplot as plt

from copy import deepcopy
from typing import Dict, Tuple, Union, List, Callable, Any, Optional
from typing import Dict, Tuple, Union, List, Callable, Any
from tqdm import tqdm
from functools import partial

Expand Down
6 changes: 3 additions & 3 deletions embodichain/lab/gym/envs/action_bank/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np

from copy import deepcopy
from typing import List, Union, Optional
from typing import List

from embodichain.utils import logger
from embodichain.lab.gym.utils.misc import validation_with_process_from_name
Expand All @@ -33,7 +33,7 @@ def generate_affordance_from_src(
env,
src_key: str,
dst_key: str,
valid_funcs_name_kwargs_proc: Optional[List] = None,
valid_funcs_name_kwargs_proc: list | None = None,
to_array: bool = True,
) -> bool:
"""Generate a new affordance entry in env.affordance_datas by applying a validation and processing
Expand All @@ -43,7 +43,7 @@ def generate_affordance_from_src(
env: The environment object containing affordance data.
src_key (str): The key of the source affordance in env.affordance_datas.
dst_key (str): The key to store the generated affordance in env.affordance_datas.
valid_funcs_name_kwargs_proc (Optional[List]): A list of validation or processing functions (with kwargs)
valid_funcs_name_kwargs_proc (list | None): A list of validation or processing functions (with kwargs)
to apply to the source affordance. Defaults to an empty list.
to_array (bool): Whether to convert the result to a numpy array before storing. Defaults to True.

Expand Down
8 changes: 4 additions & 4 deletions embodichain/lab/gym/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
import gymnasium as gym

from typing import Dict, List, Union, Tuple, Any, Optional, Sequence
from typing import Dict, List, Union, Tuple, Any, Sequence
from functools import cached_property

from embodichain.lab.sim.types import EnvObs, EnvAction
Expand All @@ -41,7 +41,7 @@ class EnvCfg:
sim_cfg: SimulationManagerCfg = SimulationManagerCfg()
"""Simulation configuration for the environment."""

seed: Optional[int] = None
seed: int | None = None
"""The seed for the random number generator. Defaults to -1, in which case the seed is not set.

Note:
Expand Down Expand Up @@ -272,7 +272,7 @@ def _update_sim_state(self, **kwargs):
# TODO: Add randomization event here.
pass

def _initialize_episode(self, env_ids: Optional[Sequence[int]] = None, **kwargs):
def _initialize_episode(self, env_ids: Sequence[int] | None = None, **kwargs):
"""Initialize the simulation assets before each episode. Randomization can be performed at this stage.

Args:
Expand Down Expand Up @@ -427,7 +427,7 @@ def _step_action(self, action: EnvAction) -> EnvAction:
pass

def reset(
self, seed: Optional[int] = None, options: Optional[Dict] = None
self, seed: int | None = None, options: dict | None = None
) -> Tuple[EnvObs, Dict]:
"""Reset the SimulationManager environment and return the observation and info.

Expand Down
14 changes: 7 additions & 7 deletions embodichain/lab/gym/envs/embodied_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import gymnasium as gym

from dataclasses import MISSING
from typing import Dict, Union, Optional, Sequence, Tuple, Any, List
from typing import Dict, Union, Sequence, Tuple, Any, List

from embodichain.lab.sim.cfg import (
RobotCfg,
Expand Down Expand Up @@ -253,7 +253,7 @@ def get_affordance(self, key: str, default: Any = None):
return self.affordance_datas.get(key, default)

def reset(
self, seed: Optional[int] = None, options: Optional[Dict] = None
self, seed: int | None = None, options: dict | None = None
) -> Tuple[EnvObs, Dict]:
obs, info = super().reset(seed=seed, options=options)

Expand Down Expand Up @@ -297,7 +297,7 @@ def _update_sim_state(self, **kwargs) -> None:
self.event_manager.apply(mode="interval")

def _initialize_episode(
self, env_ids: Optional[Sequence[int]] = None, **kwargs
self, env_ids: Sequence[int] | None = None, **kwargs
) -> None:
# apply events such as randomization for environments that need a reset
if self.cfg.events:
Expand Down Expand Up @@ -439,28 +439,28 @@ def preview_sensor_data(
plt.imshow(view)
plt.savefig(f"sensor_data_{data_type}.png")

def create_demo_action_list(self, *args, **kwargs) -> Optional[Sequence[EnvAction]]:
def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None:
"""Create a demonstration action list for the environment.

This function should be implemented in subclasses to generate a sequence of actions
that demonstrate a specific task or behavior within the environment.

Returns:
Optional[Sequence[EnvAction]]: A list of actions if a demonstration is available, otherwise None.
Sequence[EnvAction] | None: A list of actions if a demonstration is available, otherwise None.
"""
raise NotImplementedError(
"The method 'create_demo_action_list' must be implemented in subclasses."
)

def to_dataset(self, id: str, save_path: str = None) -> Optional[str]:
def to_dataset(self, id: str, save_path: str = None) -> str | None:
"""Convert the recorded episode data to a dataset format.

Args:
id (str): Unique identifier for the dataset.
save_path (str, optional): Path to save the dataset. If None, use config or default.

Returns:
Optional[str]: The path to the saved dataset, or None if failed.
str | None: The path to the saved dataset, or None if failed.
"""
raise NotImplementedError(
"The method 'to_dataset' will be implemented in the near future."
Expand Down
10 changes: 5 additions & 5 deletions embodichain/lab/gym/envs/managers/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import os
import random
from typing import TYPE_CHECKING, Literal, Union, Optional, List, Dict, Sequence
from typing import TYPE_CHECKING, Literal, Union, List, Dict, Sequence

from embodichain.lab.sim.objects import RigidObject, Articulation, Robot
from embodichain.lab.sim.sensors import Camera, StereoCamera
Expand Down Expand Up @@ -388,7 +388,7 @@ def _project_3d_to_2d(
return points_2d

def _get_gripper_ratio(
self, control_part: str, gripper_qpos: Optional[torch.Tensor] = None
self, control_part: str, gripper_qpos: torch.Tensor | None = None
):
robot: Robot = self._env.robot
gripper_max_limit = robot.body_data.qpos_limits[
Expand All @@ -402,11 +402,11 @@ def _get_gripper_ratio(

def _get_robot_exteroception(
self,
control_part: Optional[str] = None,
control_part: str | None = None,
x_interval: float = 0.02,
y_interval: float = 0.02,
kpnts_number: int = 12,
offset: Optional[Union[List, torch.Tensor]] = None,
offset: list | torch.Tensor | None = None,
follow_eef: bool = False,
) -> torch.Tensor:
"""Get the robot exteroception poses.
Expand Down Expand Up @@ -468,7 +468,7 @@ def _get_object_exteroception(
y_interval: float = 0.02,
kpnts_number: int = 12,
is_arena_coord: bool = False,
follow_eef: Optional[str] = None,
follow_eef: str | None = None,
) -> torch.Tensor:
"""Get the rigid object exteroception poses.

Expand Down
50 changes: 25 additions & 25 deletions embodichain/lab/gym/envs/managers/randomization/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import os
import random
from typing import TYPE_CHECKING, Literal, Union, Optional, Dict
from typing import TYPE_CHECKING, Literal, Union, Dict

from embodichain.lab.sim.objects import Light, RigidObject, Articulation
from embodichain.lab.sim.sensors import Camera, StereoCamera
Expand Down Expand Up @@ -55,11 +55,11 @@ def randomize_camera_extrinsics(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
entity_cfg: SceneEntityCfg,
pos_range: Optional[tuple[list[float], list[float]]] = None,
euler_range: Optional[tuple[list[float], list[float]]] = None,
eye_range: Optional[tuple[list[float], list[float]]] = None,
target_range: Optional[tuple[list[float], list[float]]] = None,
up_range: Optional[tuple[list[float], list[float]]] = None,
pos_range: tuple[list[float], list[float]] | None = None,
euler_range: tuple[list[float], list[float]] | None = None,
eye_range: tuple[list[float], list[float]] | None = None,
target_range: tuple[list[float], list[float]] | None = None,
up_range: tuple[list[float], list[float]] | None = None,
) -> None:
"""
Randomize camera extrinsic properties (position and orientation).
Expand Down Expand Up @@ -177,9 +177,9 @@ def randomize_light(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
entity_cfg: SceneEntityCfg,
position_range: Optional[tuple[list[float], list[float]]] = None,
color_range: Optional[tuple[list[float], list[float]]] = None,
intensity_range: Optional[tuple[float, float]] = None,
position_range: tuple[list[float], list[float]] | None = None,
color_range: tuple[list[float], list[float]] | None = None,
intensity_range: tuple[float, float] | None = None,
) -> None:
"""Randomize light properties by adding, scaling, or setting random values.

Expand All @@ -205,9 +205,9 @@ def randomize_light(
env (EmbodiedEnv): The environment instance.
env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization.
entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize.
position_range (Optional[tuple[list[float], list[float]]]): The range for the position randomization.
color_range (Optional[tuple[list[float], list[float]]]): The range for the color randomization.
intensity_range (Optional[tuple[float, float]]): The range for the intensity randomization.
position_range (tuple[list[float], list[float]] | None): The range for the position randomization.
color_range (tuple[list[float], list[float]] | None): The range for the color randomization.
intensity_range (tuple[float, float] | None): The range for the intensity randomization.
"""

light: Light = env.sim.get_light(entity_cfg.uid)
Expand Down Expand Up @@ -259,10 +259,10 @@ def randomize_camera_intrinsics(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
entity_cfg: SceneEntityCfg,
focal_x_range: Optional[tuple[float, float]] = None,
focal_y_range: Optional[tuple[float, float]] = None,
cx_range: Optional[tuple[float, float]] = None,
cy_range: Optional[tuple[float, float]] = None,
focal_x_range: tuple[float, float] | None = None,
focal_y_range: tuple[float, float] | None = None,
cx_range: tuple[float, float] | None = None,
cy_range: tuple[float, float] | None = None,
) -> None:
"""Randomize camera intrinsic properties by adding, scaling, or setting random values.

Expand All @@ -289,10 +289,10 @@ def randomize_camera_intrinsics(
env (EmbodiedEnv): The environment instance.
env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization.
entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize.
focal_x_range (Optional[tuple[float, float]]): The range for the focal length x randomization.
focal_y_range (Optional[tuple[float, float]]): The range for the focal length y randomization.
cx_range (Optional[tuple[float, float]]): The range for the principal point x randomization.
cy_range (Optional[tuple[float, float]]): The range for the principal point y randomization.
focal_x_range (tuple[float, float] | None): The range for the focal length x randomization.
focal_y_range (tuple[float, float] | None): The range for the focal length y randomization.
cx_range (tuple[float, float] | None): The range for the principal point x randomization.
cy_range (tuple[float, float] | None): The range for the principal point y randomization.
"""

camera: Union[Camera, StereoCamera] = env.sim.get_sensor(entity_cfg.uid)
Expand Down Expand Up @@ -500,11 +500,11 @@ def __call__(
env_ids: Union[torch.Tensor, None],
entity_cfg: SceneEntityCfg,
random_texture_prob: float = 0.5,
texture_path: Optional[str] = None,
base_color_range: Optional[tuple[list[float], list[float]]] = None,
metallic_range: Optional[tuple[float, float]] = None,
roughness_range: Optional[tuple[float, float]] = None,
ior_range: Optional[tuple[float, float]] = None,
texture_path: str | None = None,
base_color_range: tuple[list[float], list[float]] | None = None,
metallic_range: tuple[float, float] | None = None,
roughness_range: tuple[float, float] | None = None,
ior_range: tuple[float, float] | None = None,
):
from embodichain.lab.sim.utility import is_rt_enabled

Expand Down
Loading
Loading