Commit c522284

Merge pull request #7 from OpenDriveLab/dev
Release Train-Deploy Alignment module and advantage labels
2 parents 18b056f + e8a54a7 commit c522284

1,172 files changed: +349094 −49 lines changed


README.md

Lines changed: 60 additions & 16 deletions
@@ -21,7 +21,7 @@

 - **[Model Arithmetic](#model-arithmetic)**: A weight-space merging strategy that combines models trained on different data subsets, efficiently capturing diverse knowledge without architectural complexity. **[Released]**
 - **[Stage Advantage](#stage-advantage)**: A stage-aware advantage estimator that provides stable, dense progress signals for policy training. **[Released]**
-- **[Train-Deploy Alignment](#train-deploy-alignment-coming-soon)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Coming Soon]**
+- **[Train-Deploy Alignment](#train-deploy-alignment)**: Bridges the distribution gap via spatio-temporal augmentation, heuristic DAgger corrections, and temporal chunk-wise smoothing. **[Released]**

 χ₀ enables two sets of dual-arm robots to collaboratively orchestrate long-horizon garment manipulation — flattening, folding, and hanging — surpassing the state-of-the-art $\pi_{0.5}$ baseline by approximately 250% in success rate, with `only 20 hours of data and 8 A100 GPUs`.
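The Model Arithmetic bullet above describes weight-space merging. As a minimal sketch of the idea, uniform (or weighted) parameter averaging over checkpoints; this is illustrative only and not necessarily the module's exact strategy:

```python
# Minimal sketch of weight-space merging: combine parameter dictionaries of
# models trained on different data subsets by (weighted) averaging.
# Illustrative only -- the Model Arithmetic module may use other arithmetic.

def merge_state_dicts(state_dicts, weights=None):
    """Merge {name: list-of-floats} parameter dicts by weighted average."""
    n = len(state_dicts)
    weights = weights or [1.0 / n] * n
    merged = {}
    for name in state_dicts[0]:
        dim = len(state_dicts[0][name])
        merged[name] = [
            sum(w * sd[name][i] for w, sd in zip(weights, state_dicts))
            for i in range(dim)
        ]
    return merged

a = {"layer.w": [0.0, 2.0]}
b = {"layer.w": [2.0, 4.0]}
print(merge_state_dicts([a, b]))  # {'layer.w': [1.0, 3.0]}
```

Real checkpoints would use tensors rather than lists, but the merge rule is the same per-parameter convex combination.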

@@ -47,13 +47,15 @@ https://github.com/user-attachments/assets/3f5f0c48-ff3f-4b9b-985b-59ad0b2ea97c
 - [Workflow](#workflow)
 - [Quick Start](#quick-start)
 - [Stage Advantage](#stage-advantage)
-- [Train-Deploy Alignment (Coming Soon)](#train-deploy-alignment-coming-soon)
+- [Train-Deploy Alignment](#train-deploy-alignment)
 - [Citation](#licenseandcitation)
 - [Troubleshooting](#troubleshooting)
 - [Links and Community](#links-and-community)

 ## Update

+- [Feb 15 2026] Stage Advantage **advantage labels** (`Task_A/advantage/`) released on [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) and [ModelScope](https://www.modelscope.cn/datasets/OpenDriveLab/Kai0).
+- [Feb 15 2026] Release of the **Train-Deploy Alignment** module: data augmentation (time scaling, space mirroring), DAgger data collection, inference with temporal smoothing/ensembling and RTC, and HDF5-to-LeRobot conversion.
 - [Feb 14 2026] Release of the **Stage Advantage** module: advantage estimator training, evaluation, GT labeling, and AWBC training pipeline.
 - [Feb 10 2026] Initial release of the **Model Arithmetic** module with support for both JAX and PyTorch checkpoints (not tested thoroughly).
 - [Feb 10 2026] χ₀ paper released.
@@ -80,7 +82,7 @@ Non-edge components (e.g., Policy Training, Model Arithmetic) have been tested o

 ### Hardware

-For real-robot deployment (dual-arm setup, cameras, and table layout), see **[Hardware Setup & 3D Print Files](setup/README.md)**. That document covers supported platforms (Agilex Piper for FlattenFold / TeeShirtSort, ARX X5 for HangCloth), Intel RealSense D435i camera placement, 3D-printed grippers and mounts with usage notes, and inference host GPU (RTX 4090 in Ubuntu 20.04).
+For real-robot deployment (dual-arm setup, cameras, and table layout), see **[Hardware Setup & 3D Print Files](setup/README.md)**. That document covers supported platforms (Agilex Piper for Task_A / Task_B, ARX X5 for Task_C), Intel RealSense D435i camera placement, 3D-printed grippers and mounts with usage notes, and inference host GPU (RTX 4090 in Ubuntu 20.04).

 ## Installation
@@ -116,11 +118,11 @@ Download the Kai0 dataset so it is available under `./data` for training and eva
 python scripts/download_dataset.py
 ```

-This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).
+This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (Task_A, Task_B, Task_C). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).

 ### 2. Download checkpoints (optional, for testing)

-We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).
+We provide **one best model per task** (Task_A, Task_B, Task_C) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).

 From the repository root, you can download all best-model checkpoints to `./checkpoints` with:
@@ -131,7 +133,7 @@ python scripts/download_checkpoints.py
 To download only specific tasks or use a custom path, run:

 ```bash
-python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_checkpoints
+python scripts/download_checkpoints.py --tasks Task_A Task_C --local-dir ./my_checkpoints
 ```

 After download, set `weight_loader` in the training config to the path of the corresponding checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.₅ checkpoint instead.
@@ -144,7 +146,7 @@ After the dataset is in `./data`, you can run **normal π₀.₅ full fine-tunin
 Edit [`src/openpi/training/config.py`](src/openpi/training/config.py) (around lines 1173–1226) for the task(s) you need:

-- **`repo_id`**: set to the **absolute path** to the dataset subset, e.g. `<path_to_repo_root>/data/FlattenFold/base`, `<path_to_repo_root>/data/TeeShirtSort/base`, or `<path_to_repo_root>/data/HangCloth/base`.
+- **`repo_id`**: set to the **absolute path** to the dataset subset, e.g. `<path_to_repo_root>/data/Task_A/base`, `<path_to_repo_root>/data/Task_B/base`, or `<path_to_repo_root>/data/Task_C/base`.
 - **`weight_loader`**: set to the path of your **π₀.₅ base checkpoint** — either the best model you downloaded in step 2 above, or openpi’s pretrained π₀.₅ checkpoint.

 Config names to use: e.g. `pi05_flatten_fold_normal`
@@ -210,8 +212,8 @@ Checkpoints are written to the config’s checkpoint directory. You can then use
 - [x] kai0 oracle: training and inference code with non-advantage data of three tasks
 - [x] Model Arithmetic: code of different baselines for weight-space interpolation
 - [x] Stage Advantage: code, data (advantage labels), and checkpoints
-- [ ] HuggingFace & ModelScope: upload Stage Advantage data and checkpoints — **Feb 14**
-- [ ] Train-Deploy Alignment — **Feb 14**
+- [x] Train-Deploy Alignment: data augmentation, DAgger, inference (temporal smoothing, ensembling, RTC)
+- [x] HuggingFace & ModelScope: Stage Advantage data (`Task_A/advantage/`) and checkpoints uploaded

 ## Model Arithmetic
@@ -300,7 +302,7 @@ For a ready-to-use script with environment setup (conda/venv activation, DDP con
 **Stage 2 — Advantage Estimation on New Data**: Use the trained estimator to label datasets with predicted advantage values.

 ```bash
-uv run python stage_advantage/annotation/eval.py Flatten-Fold KAI0 /path/to/dataset
+uv run python stage_advantage/annotation/eval.py Task-A KAI0 /path/to/dataset
 ```

 For a ready-to-use script with environment setup and status logging, see `stage_advantage/annotation/eval.sh`.
@@ -315,14 +317,56 @@ For a ready-to-use script with environment setup and automatic log management, s

 For the full pipeline details, configuration instructions, and all parameters, see [`stage_advantage/README.md`](stage_advantage/README.md).

-## Train-Deploy Alignment (Coming Soon)
+## Train-Deploy Alignment

-Train-Deploy Alignment bridges the distribution gap between training and real-world deployment through:
-- **Spatio-temporal augmentation**: Data augmentation including space mirroring and time scaling for dual-arm setups.
-- **Heuristic DAgger corrections**: Interactive on-robot data collection for iterative policy improvement.
-- **Temporal chunk-wise smoothing**: Smoothed action execution to reduce jitter during deployment.
+Train-Deploy Alignment bridges the distribution gap between training and real-world deployment through three sub-modules:

-**This module is currently under refinement and will be released soon.**
+- **Data Augmentation** (`train_deploy_alignment/data_augment/`): Time scaling (frame extraction at configurable rates), space mirroring (left/right arm swap + video flip), dataset merging, and HDF5-to-LeRobot format conversion.
+- **DAgger** (`train_deploy_alignment/dagger/`): Policy-in-the-loop data collection for both Agilex Piper and ARX X5 platforms. Operators run inference, switch to DAgger mode for human corrections, and save episodes (HDF5 + optional videos + intervention labels).
+- **Inference** (`train_deploy_alignment/inference/`): Deployment code for Agilex and ARX robots with multiple execution modes — synchronous, temporal smoothing, temporal ensembling, and **RTC (real-time chunking)**. Uses a two-machine setup (GPU policy server + robot IPC client).
+
+### Quick Start
+
+**Data Augmentation — Time scaling:**
+
+```bash
+python train_deploy_alignment/data_augment/time_scaling.py \
+  --src_path /path/to/source --tgt_path /path/to/extracted --repo_id extracted_dataset \
+  --extraction_factor 2
+```
+
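The time-scaling command above extracts frames at a configurable rate. The core operation amounts to subsampling an episode, sketched here on a plain list (a hypothetical helper; the real `time_scaling.py` works on full datasets, not lists):

```python
# Sketch of time-scaling augmentation: keep every `extraction_factor`-th
# frame of an episode, producing a shorter, faster-looking trajectory.
# Hypothetical helper -- illustrative of the idea only.

def extract_frames(frames, extraction_factor=2):
    """Keep every `extraction_factor`-th frame (factor 1 = unchanged)."""
    if extraction_factor < 1:
        raise ValueError("extraction_factor must be >= 1")
    return frames[::extraction_factor]

print(extract_frames(list(range(10)), 2))  # [0, 2, 4, 6, 8]
```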
+**Data Augmentation — Space mirroring (mirror + merge):**
+
+```bash
+python train_deploy_alignment/data_augment/space_mirroring.py full \
+  --src-path /path/to/original --mirror-path /path/to/mirrored --merge-path /path/to/merged \
+  --repo-id my_dataset
+```
+
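Space mirroring swaps the left and right arms and horizontally flips the video. A minimal per-step sketch with hypothetical field names (the real script operates on whole episodes and video files):

```python
# Sketch of space mirroring for a dual-arm setup: swap left/right arm
# channels and flip each image row left-to-right. Field names here are
# hypothetical stand-ins for the dataset's actual keys.

def mirror_step(step):
    mirrored = dict(step)
    mirrored["left_joints"], mirrored["right_joints"] = (
        step["right_joints"],
        step["left_joints"],
    )
    # Horizontal flip: image as nested lists (H rows of W pixels).
    mirrored["image"] = [row[::-1] for row in step["image"]]
    return mirrored

step = {"left_joints": [0.1], "right_joints": [0.2], "image": [[1, 2, 3]]}
print(mirror_step(step))
# {'left_joints': [0.2], 'right_joints': [0.1], 'image': [[3, 2, 1]]}
```

Note that a faithful mirror may also need sign flips on some joint axes, which depends on the robot's kinematics and is omitted here.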
+**DAgger — Agilex:** Start the policy server on the GPU host, then on the IPC:
+
+```bash
+conda activate kai0_inference
+python train_deploy_alignment/dagger/agilex/agilex_openpi_dagger_collect.py \
+  --host <gpu_host_ip> --port 8000 --ctrl_type joint --use_temporal_smoothing --chunk_size 50 \
+  --dataset_name <your_dataset_name>
+```
+
+**Inference — Agilex (temporal smoothing):** Start the policy server on the GPU host, then on the IPC:
+
+```bash
+conda activate kai0_inference
+python inference/agilex_inference_openpi_temporal_smoothing.py \
+  --host <gpu_host_ip> --port 8000 --ctrl_type joint --use_temporal_smoothing --chunk_size 50
+```
+
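Temporal smoothing/ensembling averages the predictions that overlapping action chunks make for the current timestep, which reduces jitter at chunk boundaries. A simplified scalar-action sketch, assuming exponential down-weighting of older predictions (the deployed code may weight differently):

```python
# Simplified sketch of temporal ensembling over overlapping action chunks:
# at timestep t, average every issued chunk's prediction for t, giving
# less weight to predictions made further in the past (assumed weighting).
import math

def ensemble_action(chunk_history, t, m=0.1):
    """chunk_history: list of (start_t, chunk); chunk is a list of scalar actions."""
    preds, weights = [], []
    for start_t, chunk in chunk_history:
        idx = t - start_t  # how far into this chunk timestep t falls
        if 0 <= idx < len(chunk):
            preds.append(chunk[idx])
            weights.append(math.exp(-m * idx))  # older predictions get less weight
    total = sum(weights)
    return sum(w * p for w, p in zip(weights, preds)) / total

# Two chunks overlap at t=1; with m=0 they are averaged equally.
print(ensemble_action([(0, [1.0, 1.0]), (1, [3.0])], t=1, m=0.0))  # 2.0
```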
+**Inference — ARX (RTC mode):** Start the policy server with an RTC config, then on the IPC:
+
+```bash
+python inference/arx_openpi_inference_rtc.py --host <gpu_host_ip> --port 8000 --rtc_mode --chunk_size 50
+```
+
+For full setup instructions (IPC environment, CAN, ROS/ROS2, platform-specific details), see [`train_deploy_alignment/README.md`](train_deploy_alignment/README.md).

 ## License and Citation
setup/README.md

Lines changed: 10 additions & 10 deletions
@@ -6,15 +6,15 @@ Quick reference for deploying and debugging hardware for the supported task plat

 ## Table of Contents

-- [1. FlattenFold / TeeShirtSort (Agilex Piper)](#1-flattenfold--teeshirtsort-agilex-piper)
-- [2. HangCloth (ARX X5)](#2-hangcloth-arx-x5)
+- [1. Task_A / Task_B (Agilex Piper)](#1-task_a--task_b-agilex-piper)
+- [2. Task_C (ARX X5)](#2-task_c-arx-x5)
 - [3. Inference Host](#3-inference-host)

 ---

-## 1. FlattenFold / TeeShirtSort (Agilex Piper)
+## 1. Task_A / Task_B (Agilex Piper)

-**Directories:** `FlattenFold/`, `TeeShirtSort/`
+**Directories:** `Task_A/`, `Task_B/`

 ### 1.1 Components

@@ -24,7 +24,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
 | Cameras | Intel RealSense D435 (triple-camera setup) |
 | Printed parts | Left/right wrist camera mounts, center camera mount, center camera base |

-### 1.2 FlattenFold Layout
+### 1.2 Task_A Layout

 | Parameter | Value |
 |-----------|-------|
@@ -37,7 +37,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
 | Right primary arm → table front edge | 12 cm |
 | Left–right primary arm center distance | 39 cm |

-### 1.3 TeeShirtSort Layout (demoA-style)
+### 1.3 Task_B Layout (demoA-style)

 | Parameter | Value |
 |-----------|-------|
@@ -50,7 +50,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
 | Right primary arm → table front edge | 11 cm |
 | Left–right primary arm center distance | 40 cm |

-### 1.4 3D Models — Usage (FlattenFold / TeeShirtSort)
+### 1.4 3D Models — Usage (Task_A / Task_B)

 #### Gripper (end-effector)

@@ -77,9 +77,9 @@ Quick reference for deploying and debugging hardware for the supported task plat

 ---

-## 2. HangCloth (ARX X5)
+## 2. Task_C (ARX X5)

-**Directory:** `HangCloth/`
+**Directory:** `Task_C/`

 ### 2.1 Components

@@ -102,7 +102,7 @@ Quick reference for deploying and debugging hardware for the supported task plat
 | Right primary arm → table front edge | 11 cm |
 | Left–right primary arm center distance | 53 cm |

-### 2.3 3D Models — Usage (HangCloth)
+### 2.3 3D Models — Usage (Task_C)

 #### Grippers (secondary arms)

src/openpi/models/model.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ class ModelType(enum.Enum):
     PI0 = "pi0"
     PI0_FAST = "pi0_fast"
     PI05 = "pi05"
+    PI0_RTC = "pi0_rtc"
+    PI05_RTC = "pi05_rtc"


 # The model always expects these images

src/openpi/models/pi0_config.py

Lines changed: 28 additions & 0 deletions
@@ -13,6 +13,7 @@

 if TYPE_CHECKING:
     from openpi.models.pi0 import Pi0
+    from openpi.models.pi0_rtc import Pi0RTC


 @dataclasses.dataclass(frozen=True)
@@ -107,6 +108,33 @@ def get_freeze_filter(self) -> nnx.filterlib.Filter:
             return nnx.Nothing
         return nnx.All(*filters)

+
+@dataclasses.dataclass(frozen=True)
+class Pi0RTCConfig(Pi0Config):
+    """Config for the Pi0RTC (real-time chunking) model. Uses the same architecture as Pi0/Pi05, but
+    sample_actions supports prev_action_chunk, inference_delay, and execute_horizon for RTC guidance.
+    Use this config when serving for RTC inference (e.g. agilex_inference_openpi_rtc.py). Set pi05=True
+    for Pi05-based RTC (model_type PI05_RTC)."""
+
+    @property
+    @override
+    def model_type(self) -> _model.ModelType:
+        return _model.ModelType.PI05_RTC if self.pi05 else _model.ModelType.PI0_RTC
+
+    @override
+    def create(self, rng: at.KeyArrayLike) -> "Pi0RTC":
+        from openpi.models.pi0_rtc import Pi0RTC
+
+        return Pi0RTC(self, rngs=nnx.Rngs(rng))
+
+    @override
+    def load_pytorch(self, train_config, weight_path: str):
+        """RTC model is JAX-only; use a JAX checkpoint with serve_policy and Pi0RTCConfig."""
+        raise NotImplementedError(
+            "Pi0RTC is only supported with JAX checkpoints. Use a checkpoint saved from OpenPi JAX training "
+            "(params directory, not model.safetensors) and serve with --policy.config=pi05_rtc_flatten_fold_inference (or your RTC config name)."
+        )
+
+
 @dataclasses.dataclass(frozen=True)
 class AdvantageEstimatorConfig(Pi0Config):
     # * Custom
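The `Pi0RTCConfig` added above follows a standard Python pattern: a frozen dataclass subclass that overrides a `model_type` property driven by an inherited flag. A self-contained mimic of that pattern (class and enum names here are stand-ins, not openpi imports):

```python
# Self-contained mimic of the Pi0RTCConfig pattern: a frozen dataclass
# subclass reports a different model_type based on an inherited flag.
# Names are illustrative; the real classes live in openpi.
import dataclasses
import enum

class ModelType(enum.Enum):
    PI0 = "pi0"
    PI05 = "pi05"
    PI0_RTC = "pi0_rtc"
    PI05_RTC = "pi05_rtc"

@dataclasses.dataclass(frozen=True)
class Pi0Config:
    pi05: bool = False

    @property
    def model_type(self) -> ModelType:
        return ModelType.PI05 if self.pi05 else ModelType.PI0

@dataclasses.dataclass(frozen=True)
class Pi0RTCConfig(Pi0Config):
    # No new fields: only the reported model_type changes.
    @property
    def model_type(self) -> ModelType:
        return ModelType.PI05_RTC if self.pi05 else ModelType.PI0_RTC

print(Pi0RTCConfig(pi05=True).model_type)  # ModelType.PI05_RTC
```

Because properties carry no field annotation, the dataclass machinery ignores them, so the subclass stays frozen and field-compatible with its parent.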

src/openpi/models/pi0_rtc.py

Lines changed: 1 addition & 1 deletion
@@ -322,7 +322,7 @@ def rtc_step(carry):
         x_t_for_denoise = x_t
         if mask_prefix_delay and provided_dim > 0:
             mask_time = (jnp.arange(self.action_horizon) < d).astype(bool)[None, :, None]
-            # 仅覆盖提供的维度,其余保持 x_t 原值
+            # Overwrite only the provided dims in the delay prefix; leave the rest as x_t.
            overwrite = jnp.where(mask_time, prev_chunk[..., :provided_dim], x_t_for_denoise[..., :provided_dim])
             x_t_for_denoise = x_t_for_denoise.at[..., :provided_dim].set(overwrite)

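The masking above can be read in plain Python: for the first `d` timesteps of the chunk (the inference delay), the provided action dims are overwritten with the previously committed chunk; everything else keeps the current noisy sample. A list-based analogue (illustrative only; the real code uses JAX arrays, broadcasting, and `.at[].set()`):

```python
# List-based analogue of the RTC delay-prefix masking above.

def mask_prefix(x_t, prev_chunk, d, provided_dim):
    """x_t, prev_chunk: horizon-length lists of action vectors (lists of floats)."""
    out = []
    for t, (x_row, prev_row) in enumerate(zip(x_t, prev_chunk)):
        row = list(x_row)
        if t < d:  # inside the delay prefix: keep the previously committed actions
            row[:provided_dim] = prev_row[:provided_dim]
        out.append(row)
    return out

x = [[0.0, 0.0]] * 4     # noisy sample, horizon 4, action dim 2
prev = [[1.0, 1.0]] * 4  # previously executed chunk
print(mask_prefix(x, prev, d=2, provided_dim=1))
# [[1.0, 0.0], [1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
```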
src/openpi/policies/agilex_policy.py

Lines changed: 2 additions & 2 deletions
@@ -52,8 +52,8 @@ class AgilexInputs(transforms.DataTransformFn):
     mask_state: bool = False

     def __call__(self, data: dict) -> dict:
-        # We only mask padding for pi0 model, not pi0-FAST
-        mask_padding = self.model_type == _model.ModelType.PI0
+        # We only mask padding for the pi0/pi0_rtc models, not pi05/pi05_rtc or pi0-FAST
+        mask_padding = self.model_type in (_model.ModelType.PI0, _model.ModelType.PI0_RTC)

         in_images = data["images"]

src/openpi/policies/arx_policy.py

Lines changed: 2 additions & 2 deletions
@@ -52,8 +52,8 @@ class ARXInputs(transforms.DataTransformFn):
     mask_state: bool = False

     def __call__(self, data: dict) -> dict:
-        # We only mask padding for pi0 model, not pi0-FAST
-        mask_padding = self.model_type == _model.ModelType.PI0
+        # We only mask padding for the pi0/pi0_rtc models, not pi05/pi05_rtc or pi0-FAST
+        mask_padding = self.model_type in (_model.ModelType.PI0, _model.ModelType.PI0_RTC)

         in_images = data["images"]

src/openpi/policies/droid_policy.py

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ def __call__(self, data: dict) -> dict:
         wrist_image = _parse_image(data["observation/wrist_image_left"])

         match self.model_type:
-            case _model.ModelType.PI0 | _model.ModelType.PI05:
+            case _model.ModelType.PI0 | _model.ModelType.PI05 | _model.ModelType.PI0_RTC | _model.ModelType.PI05_RTC:
                 names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
                 images = (base_image, wrist_image, np.zeros_like(base_image))
                 image_masks = (np.True_, np.True_, np.False_)

src/openpi/training/config.py

Lines changed: 20 additions & 3 deletions
@@ -115,7 +115,7 @@ class ModelTransformFactory(GroupFactory):

     def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
         match model_config.model_type:
-            case _model.ModelType.PI0:
+            case _model.ModelType.PI0 | _model.ModelType.PI0_RTC:
                 return _transforms.Group(
                     inputs=[
                         _transforms.InjectDefaultPrompt(self.default_prompt),
@@ -126,7 +126,7 @@ def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
                         _transforms.PadStatesAndActions(model_config.action_dim),
                     ],
                 )
-            case _model.ModelType.PI05:
+            case _model.ModelType.PI05 | _model.ModelType.PI05_RTC:
                 assert isinstance(model_config, pi0_config.Pi0Config)
                 return _transforms.Group(
                     inputs=[
@@ -187,7 +187,7 @@ def create_base_config(self, assets_dirs: pathlib.Path, model_config: _model.Bas
             repo_id=repo_id,
             asset_id=asset_id,
             norm_stats=self._load_norm_stats(epath.Path(self.assets.assets_dir or assets_dirs), asset_id),
-            use_quantile_norm=model_config.model_type != ModelType.PI0,
+            use_quantile_norm=model_config.model_type not in (ModelType.PI0, ModelType.PI0_RTC),
         )

     def _load_norm_stats(self, assets_dir: epath.Path, asset_id: str | None) -> dict[str, _transforms.NormStats] | None:
@@ -1371,6 +1371,23 @@ def __post_init__(self) -> None:
             num_workers=8,
             batch_size=256,
         ),
+
+        # ************************** FlattenFold RTC Inference *******************************
+        # Use this config when serving the policy for agilex_inference_openpi_rtc.py (JAX checkpoints only).
+        TrainConfig(
+            name="pi05_rtc_flatten_fold_inference",
+            model=pi0_config.Pi0RTCConfig(pi05=True),
+            data=LerobotAgilexDataConfig(
+                repo_id="<path_to_repo_root>/data/FlattenFold/base",
+                default_prompt="Flatten and fold the cloth.",
+                use_delta_joint_actions=False,
+            ),
+            weight_loader=weight_loaders.CheckpointWeightLoader("<path_to/pi05_base/checkpoint>"),
+            num_train_steps=100_000,
+            keep_period=5000,
+            num_workers=8,
+            batch_size=256,
+        ),
         # RoboArena & PolaRiS configs.
         *roboarena_config.get_roboarena_configs(),
         *polaris_config.get_polaris_configs(),
