Commit 44ed698

Merge pull request #2 from OpenDriveLab/dev

Dev

2 parents e4d3bd3 + a15b6a5 · commit 44ed698

8 files changed: +411 −171 lines

.gitignore

Lines changed: 2 additions & 0 deletions

```diff
@@ -9,8 +9,10 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+checkpoints/
 develop-eggs/
 dist/
+data/
 downloads/
 eggs/
 .eggs/
```

README.md

Lines changed: 20 additions & 7 deletions

````diff
@@ -1,5 +1,4 @@
-# χ₀: Resource-Aware Robust Manipulation viaTaming Distributional Inconsistencies
-
+# χ₀: Resource-Aware Robust Manipulation via Taming Distributional Inconsistencies
 
 <div id="top" align="center">
 
@@ -16,6 +15,7 @@
 </div>
 
 χ₀ (**kai0**) is a resource-efficient framework for achieving production-level robustness in robotic manipulation by taming distributional inconsistencies.
+<!-- This repository is built on top of [openpi](https://github.com/Physical-Intelligence/openpi), the open-source models and packages for robotics published by the [Physical Intelligence team](https://www.physicalintelligence.company/). -->
 
 χ₀ addresses the systematic distributional shift among the human demonstration distribution ($P_\text{train}$), the inductive bias learned by the policy ($Q_\text{model}$), and the test-time execution distribution ($P_\text{test}$) through three technical modules:
 
@@ -31,7 +31,7 @@ https://github.com/user-attachments/assets/e662f096-d273-4458-abd4-e12b9685a9bc
 
 ## Table of Contents
 
-- [Updates](#updates)
+- [Update](#update)
 - [Acknowledgement](#acknowledgement)
 - [Requirements](#requirements)
 - [Compute](#compute)
@@ -75,7 +75,7 @@ This repository is built on top of [openpi](https://github.com/Physical-Intelligence/openpi)
 
 For Model Arithmetic (mixing checkpoints), GPU memory requirements depend on the model size and number of checkpoints being mixed. A single A100 (80GB) is sufficient for most use cases.
 
-The repo has been tested with Ubuntu 22.04.
+Non-edge components (e.g., Policy Training, Model Arithmetic) have been tested on Ubuntu 22.04.
 
 ### Hardware
 
@@ -112,15 +112,28 @@ uv pip install safetensors
 Download the Kai0 dataset so it is available under `./data` for training and evaluation. From the repository root, run:
 
 ```bash
-pip install huggingface_hub # if not already installed
 python scripts/download_dataset.py
 ```
 
-This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see [DATASET.md](DATASET.md#step-1-download-the-dataset).
+This fetches the full dataset from [Hugging Face](https://huggingface.co/datasets/OpenDriveLab-org/Kai0) into `./data` (FlattenFold, HangCloth, TeeShirtSort). To download only specific tasks or use a custom path, see the [dataset docs](docs/dataset.md#step-1-download-the-dataset).
 
 ### 2. Download checkpoints (optional, for testing)
 
-We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main). Download the task folder(s) you need and set `weight_loader` in config to the path of the downloaded checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead.
+We provide **one best model per task** (FlattenFold, HangCloth, TeeShirtSort) in the [Kai0 repo on Hugging Face](https://huggingface.co/OpenDriveLab-org/Kai0/tree/main).
+
+From the repository root, you can download all best-model checkpoints to `./checkpoints` with:
+
+```bash
+python scripts/download_checkpoints.py
+```
+
+To download only specific tasks or use a custom path, run:
+
+```bash
+python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_checkpoints
+```
+
+After download, set `weight_loader` in the training config to the path of the corresponding checkpoint directory (see step 3 below). You can also use openpi’s pretrained π₀.5 checkpoint instead.
 
 ### 3. Fine-tune with normal π₀.5
 
````
File renamed without changes.
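The README hunk above mentions Model Arithmetic (mixing checkpoints) but the mixing code itself is not part of this commit. As a rough illustration only, checkpoint mixing is commonly a weighted average over matching parameter tensors; the helper name `mix_checkpoints` below is hypothetical and may differ from the repository's actual Model Arithmetic module (which could, e.g., use task vectors instead of plain averaging):

```python
import numpy as np

def mix_checkpoints(checkpoints, weights):
    """Weighted average of checkpoints given as {param_name: ndarray} dicts.

    Hypothetical sketch; assumes every checkpoint shares the same parameter
    names and shapes.
    """
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()  # normalize mixing coefficients
    mixed = {}
    for name in checkpoints[0]:
        # Combine the same tensor from each checkpoint with its weight.
        mixed[name] = sum(w * ckpt[name] for w, ckpt in zip(weights, checkpoints))
    return mixed

# Example: mixing two tiny "checkpoints" with a 75/25 split.
a = {"layer.w": np.array([1.0, 2.0])}
b = {"layer.w": np.array([3.0, 6.0])}
out = mix_checkpoints([a, b], [3, 1])
```

Memory scales with the number of checkpoints held in RAM at once, which matches the README's note that requirements depend on model size and checkpoint count.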

docs/norm_stats_fast.md

File mode changed: 100755 → 100644
Lines changed: 136 additions & 50 deletions

````diff
@@ -1,69 +1,155 @@
-# Normalization statistics
+## Fast normalization stats computation (`compute_norm_states_fast.py`)
 
-Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
+This script provides a **fast path** to compute normalization statistics for Kai0 configs by
+directly reading local parquet files instead of going through the full data loader. It produces
+`norm_stats` that are **compatible with the original openpi pipeline** (same `RunningStats`
+implementation and batching scheme).
 
-## Reloading normalization statistics
+---
 
-When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
+### When to use this script
 
-**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
+- You have already **downloaded the dataset locally** (e.g. under `./data`, see
+  [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset)).
+- You have a **training config** in `src/openpi/training/config.py` (e.g.
+  `pi05_flatten_fold_normal`) and you want to compute `norm_stats` before running
+  `scripts/train.py`.
+- You prefer a **simpler / faster** pipeline compared to the original `compute_norm_stats.py`
+  while keeping numerically compatible statistics.
 
-```python
-TrainConfig(
-    ...
-    data=LeRobotAlohaDataConfig(
-        ...
-        assets=AssetsConfig(
-            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
-            asset_id="trossen",
-        ),
-    ),
-)
-```
-
-For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
+---
 
-**Note:** To successfully reload normalization statistics, it's important that your robot + dataset are following the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
+### Script entry point
 
-**Note #2:** Whether reloading normalization statistics is beneficial depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. We recommend to always try both, reloading and training with a fresh set of statistics computed on your new dataset (see [main README](../README.md) for instructions on how to compute new statistics), and pick the one that works better for your task.
+The script lives at:
 
+- `scripts/compute_norm_states_fast.py`
 
-## Provided Pre-training Normalization Statistics
+Main entry:
 
-Below is a list of all the pre-training normalization statistics we provide. We provide them for both, the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_base/assets` and for `pi0_fast_base`, set the `assets_dir` to `gs://openpi-assets/checkpoints/pi0_fast_base/assets`.
-| Robot | Description | Asset ID |
-|-------|-------------|----------|
-| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
-| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
-| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
-| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
-| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
-| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
-| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
-| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
-| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
+- `main(config_name: str, base_dir: str | None = None, max_frames: int | None = None)`
 
+CLI is handled via [`tyro`](https://github.com/brentyi/tyro), so you call it from the repo root as:
 
-## Pi0 Model Action Space Definitions
-
-Out of the box, both the `pi0_base` and `pi0_fast_base` use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
+```bash
+uv run python scripts/compute_norm_states_fast.py --config-name <config_name> [--base-dir <path>] [--max-frames N]
 ```
-"dim_0:dim_5": "left arm joint angles",
-"dim_6": "left arm gripper position",
-"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
-"dim_13": "right arm gripper position (for bi-manual only)",
 
-# For mobile robots:
-"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
+---
+
+### Arguments
+
+- **`--config-name`** (`str`, required)
+  - Name of the TrainConfig defined in `src/openpi/training/config.py`, e.g.:
+    - `pi05_flatten_fold_normal`
+    - `pi05_tee_shirt_sort_normal`
+    - `pi05_hang_cloth_normal`
+  - Internally resolved via `_config.get_config(config_name)`.
+
+- **`--base-dir`** (`str`, optional)
+  - Base directory containing the parquet data for this config.
+  - If omitted, the script will read it from `config.data`:
+    - `data_config = config.data.create(config.assets_dirs, config.model)`
+    - `base_dir` defaults to `data_config.repo_id`
+  - This means you can either:
+    - Set `repo_id` in the config to your local dataset path (e.g.
+      `<path_to_repo_root>/data/FlattenFold/base`), or
+    - Keep `repo_id` as-is and pass `--base-dir` explicitly to point to your local copy.
+
+- **`--max-frames`** (`int`, optional)
+  - If set, stops after processing at most `max_frames` frames across all parquet files.
+  - Useful for **quick sanity checks** or debugging smaller subsets.
+
+---
+
+### What the script does
+
+1. **Load config**
+   - Uses `_config.get_config(config_name)` to get the `TrainConfig`.
+   - Calls `config.data.create(config.assets_dirs, config.model)` to build a data config.
+   - Reads `action_dim` from `config.model.action_dim`.
+
+2. **Resolve input data directory**
+   - If `base_dir` is not provided:
+     - Uses `data_config.repo_id` as the base directory.
+     - Prints a message like:
+       - `Auto-detected base directory from config: <base_dir>`
+   - Verifies that the directory exists.
+
+3. **Scan parquet files**
+   - Recursively walks `base_dir` and collects all files ending with `.parquet`.
+   - Sorts them lexicographically for **deterministic ordering** (matches dataset order).
+
+4. **Read and process data**
+   - For each parquet file:
+     - Loads it with `pandas.read_parquet`.
+     - Expects columns:
+       - `observation.state`
+       - `action`
+   - For each row:
+     - Extracts `state` and `action` arrays.
+     - Applies:
+       - `process_state(state, action_dim)`
+       - `process_actions(actions, action_dim)`
+     - These helpers:
+       - **Pad** to `action_dim` (if dimension is smaller).
+       - **Clip abnormal values** outside \([-π, π]\) to 0 (for robustness, consistent with `FakeInputs` logic).
+   - Accumulates processed arrays into:
+     - `collected_data["state"]`
+     - `collected_data["actions"]`
+   - Maintains a running `total_frames` counter and respects `max_frames` if provided.
+
+5. **Concatenate and pad**
+   - Concatenates all collected batches per key:
+     - `all_data["state"]`, `all_data["actions"]`
+   - Ensures the last dimension matches `action_dim` (pads with zeros if needed).
+
+6. **Compute statistics with `RunningStats`**
+   - Initializes one `normalize.RunningStats()` per key (`state`, `actions`).
+   - Feeds data in **batches of 32** to match the original implementation’s floating-point
+     accumulation behavior.
+   - For each key, computes:
+     - `mean`, `std`, `q01`, `q99`, etc.
+
+7. **Save `norm_stats`**
+   - Collects results into a dict `norm_stats`.
+   - Saves them with `openpi.shared.normalize.save` to:
+     - `output_path = config.assets_dirs / data_config.repo_id`
+   - Prints the output path and a success message:
+     - `✅ Normalization stats saved to <output_path>`
+
+> **Note:** The save logic mirrors the original openpi `compute_norm_stats.py` behavior so that
+> training code can load `norm_stats` transparently.
+
+---
+
+### Typical workflow with Kai0 configs
+
+1. **Download dataset**
+   - Follow [`docs/dataset.md`](./dataset.md#step-1-download-the-dataset) to download the Kai0
+     dataset under `./data` at the repo root.
+
+2. **Set config paths**
+   - Edit `src/openpi/training/config.py` for the normal π₀.5 configs (see README `Preparation`):
+     - `repo_id` → absolute path to the dataset subset, e.g.
+       `<path_to_repo_root>/data/FlattenFold/base`
+     - `weight_loader` → path to the π₀.5 base checkpoint (e.g. Kai0 best model per task).
+
+3. **Compute normalization stats**
+   - From the repo root:
+
+```bash
+uv run python scripts/compute_norm_states_fast.py --config-name pi05_flatten_fold_normal
 ```
 
-The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
+4. **Train**
+   - Then run JAX training with:
 
-For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
+```bash
+XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 \
+uv run scripts/train.py pi05_flatten_fold_normal --exp_name=<your_experiment_name>
+```
 
-General info for Pi robots:
-- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
-- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
-- Control frequencies are either 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms.
+The training code will pick up the normalization statistics saved by this script and use them
+for input normalization, in the same way as the original openpi pipeline.
 
-For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
````
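The padding and clipping behavior described in step 4 of the new doc can be sketched as follows. These are illustrative stand-ins, not the repository's actual `process_state` / `process_actions` helpers, and the simple list-based `feed_in_batches` only demonstrates the fixed batch-of-32 iteration order, not openpi's `RunningStats` internals:

```python
import numpy as np

def pad_and_clip(values, action_dim):
    """Illustrative per-frame processing: pad with zeros up to `action_dim`,
    then zero out abnormal entries outside [-pi, pi]."""
    arr = np.asarray(values, dtype=np.float32)
    if arr.shape[-1] < action_dim:
        pad_width = [(0, 0)] * (arr.ndim - 1) + [(0, action_dim - arr.shape[-1])]
        arr = np.pad(arr, pad_width)  # zero-pad the trailing dimension
    return np.where(np.abs(arr) > np.pi, 0.0, arr)

def feed_in_batches(sink, data, batch_size=32):
    """Iterate over `data` in fixed batches of 32, mirroring the accumulation
    order the doc describes; `sink` here is just a list standing in for a
    running-statistics accumulator."""
    for start in range(0, len(data), batch_size):
        sink.append(data[start:start + batch_size])

# A 2-dim vector padded to action_dim=4; the -4.0 entry is outside [-pi, pi].
frame = pad_and_clip([0.5, -4.0], action_dim=4)

# 70 frames split into batches of 32, 32, and 6.
stats_batches = []
feed_in_batches(stats_batches, np.arange(70))
```

The fixed batch size matters only for bit-level reproducibility of the floating-point accumulation; the resulting mean/std would be the same up to rounding for any batching.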

scripts/download_checkpoints.py

Lines changed: 108 additions & 0 deletions (new file)

```python
#!/usr/bin/env python3
"""
Download Kai0 best-model checkpoints from Hugging Face to the repo's ./checkpoints directory.

Run from the repository root:
    python scripts/download_checkpoints.py

Optional: download only specific tasks or set a custom output path:
    python scripts/download_checkpoints.py --tasks FlattenFold HangCloth --local-dir ./my_ckpts
"""
from __future__ import annotations

import argparse
import sys
from multiprocessing import Process
from pathlib import Path


def get_repo_root() -> Path:
    """Return the repository root (directory containing .git)."""
    path = Path(__file__).resolve().parent.parent
    if (path / ".git").exists():
        return path
    # Fallback: assume cwd is repo root
    return Path.cwd()


def main() -> int:
    parser = argparse.ArgumentParser(
        description="Download Kai0 best-model checkpoints from Hugging Face to ./checkpoints (or --local-dir)."
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Directory to save checkpoints (default: <repo_root>/checkpoints)",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        choices=["FlattenFold", "HangCloth", "TeeShirtSort"],
        default=None,
        help="Download only these task folders from the repo (default: all)",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        default="OpenDriveLab-org/Kai0",
        help="Hugging Face repo id that hosts best-model checkpoints (default: OpenDriveLab-org/Kai0)",
    )
    args = parser.parse_args()

    try:
        from huggingface_hub import snapshot_download  # type: ignore
    except ImportError:
        print("Install huggingface_hub first: pip install huggingface_hub", file=sys.stderr)
        return 1

    repo_root = get_repo_root()
    local_dir = Path(args.local_dir) if args.local_dir else repo_root / "checkpoints"
    local_dir = local_dir.resolve()

    allow_patterns = None
    if args.tasks:
        # Each task corresponds to a top-level folder in the repo.
        allow_patterns = [f"{t}/*" for t in args.tasks]
        allow_patterns.append("README.md")

    print(f"Downloading checkpoints to {local_dir}")
    print(f"Repo: {args.repo_id}" + (f", tasks: {args.tasks}" if args.tasks else " (all tasks)"))

    # Run snapshot_download in a separate process so Ctrl+C in the main process
    # can reliably terminate the download, even if the library swallows signals.
    def _worker():
        snapshot_download(
            repo_id=args.repo_id,
            repo_type="model",
            local_dir=str(local_dir),
            local_dir_use_symlinks=False,
            allow_patterns=allow_patterns,
        )

    proc = Process(target=_worker)
    proc.start()

    try:
        proc.join()
    except KeyboardInterrupt:
        print(
            "\nCheckpoint download interrupted by user (Ctrl+C). Terminating download process...",
            file=sys.stderr,
        )
        proc.terminate()
        proc.join()
        print("Partial checkpoint data may remain in:", local_dir, file=sys.stderr)
        return 130

    if proc.exitcode != 0:
        print(f"\nCheckpoint download process exited with code {proc.exitcode}", file=sys.stderr)
        return proc.exitcode or 1

    print(f"\nDone. Checkpoints are at: {local_dir}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
```
0 commit comments