Skip to content

Commit 9cc298e

Browse files
glvov-bdaigarylvovDavid Hoellerjsmith-bdai
authored
Adds image extracted features observation term and cartpole examples for it (#1191)
# Description This adds an observation term to be able to easily extract features from the images, and adds a cartpole example of using this new term. The new ResNet18 cartpole converges in less than 100 epochs. ## Type of change <!-- As you go through the list, delete the ones that are not applicable. --> - New feature (non-breaking change which adds functionality) - This change requires a documentation update ## Checklist - [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format` - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file - [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there I will update the version in the changelog and extension.toml after approval prior to merging in due to it causing merge conflicts when main updates --------- Signed-off-by: glvov-bdai <[email protected]> Signed-off-by: garylvov <[email protected]> Co-authored-by: garylvov <[email protected]> Co-authored-by: garylvov <[email protected]> Co-authored-by: David Hoeller <[email protected]> Co-authored-by: James Smith <[email protected]>
1 parent cace5c5 commit 9cc298e

File tree

12 files changed

+336
-9
lines changed

12 files changed

+336
-9
lines changed

CONTRIBUTORS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Guidelines for modifications:
4343
* Chenyu Yang
4444
* David Yang
4545
* Dorsa Rohani
46+
* Felix Yu
4647
* Gary Lvov
4748
* Giulio Romualdi
4849
* HoJin Jeon

docs/source/overview/environments.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
6161
| | | |
6262
| | |cartpole-depth-direct-link|| |
6363
+------------------+-----------------------------+-------------------------------------------------------------------------+
64+
| |cartpole| | |cartpole-resnet-link| | Move the cart to keep the pole upwards in the classic cartpole control |
65+
| | | based off of features extracted from perceptive inputs with pre-trained |
66+
| | |cartpole-theia-link| | frozen vision encoders |
67+
+------------------+-----------------------------+-------------------------------------------------------------------------+
6468

6569
.. |humanoid| image:: ../_static/tasks/classic/humanoid.jpg
6670
.. |ant| image:: ../_static/tasks/classic/ant.jpg
@@ -69,8 +73,11 @@ Classic environments that are based on IsaacGymEnvs implementation of MuJoCo-sty
6973
.. |humanoid-link| replace:: `Isaac-Humanoid-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/humanoid/humanoid_env_cfg.py>`__
7074
.. |ant-link| replace:: `Isaac-Ant-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/ant/ant_env_cfg.py>`__
7175
.. |cartpole-link| replace:: `Isaac-Cartpole-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_env_cfg.py>`__
72-
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
73-
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-Camera-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
76+
.. |cartpole-rgb-link| replace:: `Isaac-Cartpole-RGB-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
77+
.. |cartpole-depth-link| replace:: `Isaac-Cartpole-Depth-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
78+
.. |cartpole-resnet-link| replace:: `Isaac-Cartpole-RGB-ResNet18-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
79+
.. |cartpole-theia-link| replace:: `Isaac-Cartpole-RGB-TheiaTiny-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/cartpole_camera_env_cfg.py>`__
80+
7481

7582
.. |humanoid-direct-link| replace:: `Isaac-Humanoid-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/humanoid/humanoid_env.py>`__
7683
.. |ant-direct-link| replace:: `Isaac-Ant-Direct-v0 <https://github.com/isaac-sim/IsaacLab/blob/main/source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/ant/ant_env.py>`__

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ extra_standard_library = [
3636
"toml",
3737
"trimesh",
3838
"tqdm",
39+
"torchvision",
40+
"transformers",
41+
"einops" # Needed for transformers, doesn't always auto-install
3942
]
4043
# Imports from Isaac Sim and Omniverse
4144
known_third_party = [

source/extensions/omni.isaac.lab/config/extension.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.27.6"
4+
5+
version = "0.27.7"
56

67
# Description
78
title = "Isaac Lab framework for Robot Learning"

source/extensions/omni.isaac.lab/docs/CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
Changelog
22
---------
33

4+
5+
0.27.7 (2024-10-28)
6+
~~~~~~~~~~~~~~~~~~~
7+
8+
Added
9+
^^^^^
10+
11+
* Added an observation term that extracts image features with frozen pre-trained encoders (ResNet and Theia).
12+
13+
414
0.27.6 (2024-10-25)
515
~~~~~~~~~~~~~~~~~~~
616

source/extensions/omni.isaac.lab/omni/isaac/lab/envs/mdp/observations.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717
import omni.isaac.lab.utils.math as math_utils
1818
from omni.isaac.lab.assets import Articulation, RigidObject
1919
from omni.isaac.lab.managers import SceneEntityCfg
20+
from omni.isaac.lab.managers.manager_base import ManagerTermBase
21+
from omni.isaac.lab.managers.manager_term_cfg import ObservationTermCfg
2022
from omni.isaac.lab.sensors import Camera, Imu, RayCaster, RayCasterCamera, TiledCamera
2123

2224
if TYPE_CHECKING:
2325
from omni.isaac.lab.envs import ManagerBasedEnv, ManagerBasedRLEnv
2426

27+
2528
"""
2629
Root state.
2730
"""
@@ -273,6 +276,134 @@ def image(
273276
return images.clone()
274277

275278

279+
class image_features(ManagerTermBase):
    """Extracted image features from a pre-trained frozen encoder.

    This term calls the :meth:`image` function to retrieve images, and then performs
    inference on those images with a vision model taken from a "model zoo".

    The zoo maps a model name to a configuration dictionary with three callables
    (``"model"``, ``"preprocess"`` and ``"inference"``). Models are instantiated lazily,
    the first time they are requested, and cached in :attr:`model_zoo` afterwards.

    Attributes:
        default_model_zoo_cfg: The built-in Theia and ResNet configurations.
        model_zoo_cfg: The active configurations (built-in ones plus any user overrides).
        model_zoo: Cache of already instantiated models, keyed by model name.
    """

    def __init__(self, cfg: ObservationTermCfg, env: ManagerBasedEnv):
        """Build the default model zoo configuration (no model is loaded here).

        Args:
            cfg: The configuration of the observation term.
            env: The environment instance.
        """
        super().__init__(cfg, env)
        # Imported here so that these dependencies are only needed when the term is used.
        from torchvision import models
        from transformers import AutoModel

        def create_theia_model(model_name):
            return {
                # The model is placed on its target device in __call__, so that no
                # particular device (e.g. "cuda:0") has to exist at load time.
                "model": lambda: AutoModel.from_pretrained(
                    f"theaiinstitute/{model_name}", trust_remote_code=True
                ).eval(),
                # Min-max rescaling over dims (1, 2) of the image batch.
                # NOTE(review): an image that is constant over those dims makes the
                # denominator zero and yields NaN/inf - confirm this cannot happen upstream.
                "preprocess": lambda img: (img - torch.amin(img, dim=(1, 2), keepdim=True)) / (
                    torch.amax(img, dim=(1, 2), keepdim=True) - torch.amin(img, dim=(1, 2), keepdim=True)
                ),
                "inference": lambda model, images: model.forward_feature(
                    images, do_rescale=False, interpolate_pos_encoding=True
                ),
            }

        def create_resnet_model(resnet_name):
            return {
                # See the note on device placement in create_theia_model.
                "model": lambda: getattr(models, resnet_name)(pretrained=True).eval(),
                "preprocess": lambda img: (
                    img.permute(0, 3, 1, 2)  # Convert [batch, height, width, 3] -> [batch, 3, height, width]
                    - torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
                ) / torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1),
                "inference": lambda model, images: model(images),
            }

        # List of Theia models
        theia_models = [
            "theia-tiny-patch16-224-cddsv",
            "theia-tiny-patch16-224-cdiv",
            "theia-small-patch16-224-cdiv",
            "theia-base-patch16-224-cdiv",
            "theia-small-patch16-224-cddsv",
            "theia-base-patch16-224-cddsv",
        ]

        # List of ResNet models
        resnet_models = ["resnet18", "resnet34", "resnet50", "resnet101"]

        self.default_model_zoo_cfg = {}

        # Add Theia models to the zoo
        for model_name in theia_models:
            self.default_model_zoo_cfg[model_name] = create_theia_model(model_name)

        # Add ResNet models to the zoo
        for resnet_name in resnet_models:
            self.default_model_zoo_cfg[resnet_name] = create_resnet_model(resnet_name)

        # Shallow copy: user overrides passed to __call__ must not leak into the defaults.
        self.model_zoo_cfg = dict(self.default_model_zoo_cfg)
        self.model_zoo = {}

    def __call__(
        self,
        env: ManagerBasedEnv,
        sensor_cfg: SceneEntityCfg = SceneEntityCfg("tiled_camera"),
        data_type: str = "rgb",
        convert_perspective_to_orthogonal: bool = False,
        model_zoo_cfg: dict | None = None,
        model_name: str = "ResNet18",
        model_device: str | None = "cuda:0",
        reset_model: bool = False,
    ) -> torch.Tensor:
        """Extracted image features from a pre-trained frozen encoder.

        Args:
            env: The environment.
            sensor_cfg: The sensor configuration to poll. Defaults to SceneEntityCfg("tiled_camera").
            data_type: The sensor configuration datatype. Defaults to "rgb".
            convert_perspective_to_orthogonal: Whether to orthogonalize perspective depth images.
                This is used only when the data type is "distance_to_camera". Defaults to False.
            model_zoo_cfg: Map from model name to model configuration dictionary. Each model
                configuration dictionary should include the following entries:
                - "model": A callable that returns the model when invoked without arguments.
                - "preprocess": A callable that processes the images and returns the preprocessed results.
                - "inference": A callable that, when given the model and preprocessed images,
                  returns the extracted features.
            model_name: The name of the model to use for inference. If the exact name is not
                in the zoo, its lowercase form is tried as well. Defaults to "ResNet18"
                (which resolves to the built-in "resnet18" entry).
            model_device: The device to store and infer models on. This can be used to help offload
                computation from the main environment GPU. If None, the device of the images
                is used. Defaults to "cuda:0".
            reset_model: Initialize the model even if it already exists. Defaults to False.

        Returns:
            The image features, on the same device as the image.

        Raises:
            KeyError: If ``model_name`` does not match any entry of the model zoo configuration.
        """
        if model_zoo_cfg is not None:  # use other than default
            self.model_zoo_cfg.update(model_zoo_cfg)

        model_name = self._resolve_model_name(model_name)
        model_cfg = self.model_zoo_cfg[model_name]

        if model_name not in self.model_zoo or reset_model:
            # The following allows to only load a desired subset of a model zoo into memory
            # as it becomes needed, in a "lazy" evaluation.
            print(f"[INFO]: Adding {model_name} to the model zoo")
            self.model_zoo[model_name] = model_cfg["model"]()

        images = image(
            env=env,
            sensor_cfg=sensor_cfg,
            data_type=data_type,
            convert_perspective_to_orthogonal=convert_perspective_to_orthogonal,
            normalize=True,  # want this for training stability
        )

        image_device = images.device
        # Without an explicit device, run the encoder where the images already live.
        target_device = torch.device(model_device) if model_device is not None else image_device

        model = self.model_zoo[model_name]
        # torch.nn.Module has no ``.device`` attribute, so the device is read from the
        # parameters. An index-less device such as "cuda" never compares equal to "cuda:0";
        # in that case ``.to`` is simply called again, which is harmless.
        if self._get_model_device(model) != target_device:
            model = model.to(target_device)
            self.model_zoo[model_name] = model

        images = images.to(target_device)

        # The encoder is frozen: no autograd graph is needed for its forward pass.
        with torch.no_grad():
            proc_images = model_cfg["preprocess"](images)
            features = model_cfg["inference"](model, proc_images)

        return features.to(image_device).clone()

    def _resolve_model_name(self, model_name: str) -> str:
        """Return the key of the model zoo configuration that matches the given name."""
        if model_name in self.model_zoo_cfg:
            return model_name
        # The built-in entries are lowercase (e.g. "resnet18"); accept "ResNet18" as well.
        lowered_name = model_name.lower()
        if lowered_name in self.model_zoo_cfg:
            return lowered_name
        raise KeyError(
            f"Model '{model_name}' is not in the model zoo. Available models: {sorted(self.model_zoo_cfg)}"
        )

    @staticmethod
    def _get_model_device(model) -> torch.device | None:
        """Return the device of the model's first parameter, or None if it cannot be determined."""
        parameters = getattr(model, "parameters", None)
        if parameters is None:
            return None
        first_parameter = next(iter(parameters()), None)
        return None if first_parameter is None else first_parameter.device
405+
406+
276407
"""
277408
Actions.
278409
"""

source/extensions/omni.isaac.lab_tasks/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.10.10"
4+
version = "0.10.12"
55

66
# Description
77
title = "Isaac Lab Environments"

source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
Changelog
22
---------
33

4+
0.10.12 (2024-10-28)
5+
~~~~~~~~~~~~~~~~~~~~
6+
7+
Changed
8+
^^^^^^^
9+
10+
* Changed manager-based vision cartpole environment names from Isaac-Cartpole-RGB-Camera-v0
11+
and Isaac-Cartpole-Depth-Camera-v0 to Isaac-Cartpole-RGB-v0 and Isaac-Cartpole-Depth-v0
12+
13+
0.10.11 (2024-10-28)
14+
~~~~~~~~~~~~~~~~~~~~
15+
16+
Added
17+
^^^^^
18+
19+
* Added cartpole examples that use observations of features extracted from images.
20+
421
0.10.10 (2024-10-25)
522
~~~~~~~~~~~~~~~~~~~~
623

source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/direct/shadow_hand/feature_extractor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import os
88
import torch
99
import torch.nn as nn
10-
1110
import torchvision
1211

1312
from omni.isaac.lab.sensors import save_images_to_file

source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/manager_based/classic/cartpole/__init__.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
import gymnasium as gym
1111

1212
from . import agents
13-
from .cartpole_camera_env_cfg import CartpoleDepthCameraEnvCfg, CartpoleRGBCameraEnvCfg
13+
from .cartpole_camera_env_cfg import (
14+
CartpoleDepthCameraEnvCfg,
15+
CartpoleResNet18CameraEnvCfg,
16+
CartpoleRGBCameraEnvCfg,
17+
CartpoleTheiaTinyCameraEnvCfg,
18+
)
1419
from .cartpole_env_cfg import CartpoleEnvCfg
1520

1621
##
@@ -31,7 +36,7 @@
3136
)
3237

3338
gym.register(
34-
id="Isaac-Cartpole-RGB-Camera-v0",
39+
id="Isaac-Cartpole-RGB-v0",
3540
entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
3641
disable_env_checker=True,
3742
kwargs={
@@ -41,11 +46,31 @@
4146
)
4247

4348
gym.register(
44-
id="Isaac-Cartpole-Depth-Camera-v0",
49+
id="Isaac-Cartpole-Depth-v0",
4550
entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
4651
disable_env_checker=True,
4752
kwargs={
4853
"env_cfg_entry_point": CartpoleDepthCameraEnvCfg,
4954
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_camera_ppo_cfg.yaml",
5055
},
5156
)
57+
58+
gym.register(
59+
id="Isaac-Cartpole-RGB-ResNet18-v0",
60+
entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
61+
disable_env_checker=True,
62+
kwargs={
63+
"env_cfg_entry_point": CartpoleResNet18CameraEnvCfg,
64+
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml",
65+
},
66+
)
67+
68+
gym.register(
69+
id="Isaac-Cartpole-RGB-TheiaTiny-v0",
70+
entry_point="omni.isaac.lab.envs:ManagerBasedRLEnv",
71+
disable_env_checker=True,
72+
kwargs={
73+
"env_cfg_entry_point": CartpoleTheiaTinyCameraEnvCfg,
74+
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_feature_ppo_cfg.yaml",
75+
},
76+
)

0 commit comments

Comments
 (0)