Skip to content

Commit cdf8436

Browse files
committed
DROID training
1 parent 61dd0e6 commit cdf8436

File tree

3 files changed

+41
-13
lines changed

3 files changed

+41
-13
lines changed

examples/droid/README_train.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@ uv sync --group rlds
1515

1616
## Download DROID dataset
1717

18-
You can download a (slightly outdated) version of DROID with the following command (after installing the `gsutil` google cloud CLI):
18+
You can download the DROID dataset with the following command (after installing the `gsutil` Google Cloud CLI):
1919
```
20-
gsutil -m cp -r gs://gresearch/robotics/droid <your_download_path>
20+
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>
2121
```
2222

23-
Note that this version of DROID is slightly outdated: it only contains a partial set of language annotations (~30k episodes).
24-
Please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com) to get access to the most up-to-date version of the DROID RLDS dataset (with language annotations on 75k episodes)!
25-
(sorry, we are working on updating the version on the official bucket).
26-
2723
You will need 1.8TB of disk storage to download the DROID RLDS dataset.
2824

2925
## Run

src/openpi/shared/download.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import concurrent.futures
22
import datetime
3-
import getpass
43
import logging
54
import os
65
import pathlib
@@ -17,16 +16,13 @@
1716

1817
# Environment variable that controls the cache directory path; ~/.cache/openpi is used by default.
1918
_OPENPI_DATA_HOME = "OPENPI_DATA_HOME"
19+
DEFAULT_CACHE_DIR = "~/.cache/openpi"
2020

2121
logger = logging.getLogger(__name__)
2222

2323

2424
def get_cache_dir() -> pathlib.Path:
25-
default_dir = "~/.cache/openpi"
26-
if os.path.exists("/mnt/weka"): # noqa: PTH110
27-
default_dir = f"/mnt/weka/{getpass.getuser()}/.cache/openpi"
28-
29-
cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, default_dir)).expanduser().resolve()
25+
cache_dir = pathlib.Path(os.getenv(_OPENPI_DATA_HOME, DEFAULT_CACHE_DIR)).expanduser().resolve()
3026
cache_dir.mkdir(parents=True, exist_ok=True)
3127
_set_folder_permission(cache_dir)
3228
return cache_dir

src/openpi/training/config.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig
379379
)
380380

381381
data_transforms = _transforms.Group(
382-
inputs=[droid_policy.DroidInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
382+
inputs=[droid_policy.DroidInputs(model_type=model_config.model_type)],
383383
outputs=[droid_policy.DroidOutputs()],
384384
)
385385

@@ -837,6 +837,42 @@ def __post_init__(self) -> None:
837837
keep_period=20_000,
838838
num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
839839
),
840+
TrainConfig(
841+
# This config is for fine-tuning pi05 on the *full* DROID dataset.
842+
# We use RLDS data loading to make training on this large dataset tractable.
843+
# For fine-tuning on your own DROID dataset, see below.
844+
name="pi05_full_droid_finetune",
845+
model=pi0.Pi0Config(
846+
pi05=True,
847+
action_dim=32,
848+
action_horizon=16,
849+
),
850+
data=RLDSDroidDataConfig(
851+
repo_id="droid",
852+
# Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
853+
rlds_data_dir="/mnt/pi-data/kevin",
854+
action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
855+
assets=AssetsConfig(
856+
assets_dir="gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/assets/",
857+
asset_id="droid",
858+
),
859+
),
860+
weight_loader=weight_loaders.CheckpointWeightLoader(
861+
"gs://openpi-assets-preview/checkpoints/pi05_may21_280k_v1/params"
862+
),
863+
lr_schedule=_optimizer.CosineDecaySchedule(
864+
warmup_steps=1_000,
865+
peak_lr=5e-5,
866+
decay_steps=1_000_000,
867+
decay_lr=5e-5,
868+
),
869+
num_train_steps=100_000,
870+
batch_size=256,
871+
log_interval=100,
872+
save_interval=5000,
873+
keep_period=10_000,
874+
num_workers=0, # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
875+
),
840876
TrainConfig(
841877
# This config is for fine-tuning pi05-DROID on a custom (smaller) DROID dataset.
842878
# Here, we use LeRobot data format (like for all other fine-tuning examples)

0 commit comments

Comments
 (0)