
Commit 24fe03f

feat: support sft for embodiment (RLinf#436)

* feat: support openpi sft by libero data and custom data

Signed-off-by: xusi <xusiforwork@gmail.com>

1 parent 497961c · commit 24fe03f

File tree

29 files changed: +1542 −14 lines


.github/workflows/ci-tests.yml

Lines changed: 9 additions & 1 deletion

@@ -95,6 +95,13 @@ jobs:
     if: needs.check-changes.outputs.file_filter == 'true' || needs.check-changes.outputs.install_filter == 'true'
     uses: ./.github/workflows/embodied-e2e-tests.yml

+  # =============================================== sft e2e tests ====================================================
+
+  sft-e2e-tests:
+    needs: [check-changes]
+    if: needs.check-changes.outputs.file_filter == 'true' || needs.check-changes.outputs.install_filter == 'true'
+    uses: ./.github/workflows/sft-e2e-tests.yml
+
   # =============================================== scheduler tests ====================================================

   scheduler-tests:

@@ -112,7 +119,8 @@ jobs:
       unit-tests,
       agent-reason-e2e-tests,
       embodied-e2e-tests,
-      scheduler-tests
+      scheduler-tests,
+      sft-e2e-tests
     ]
     if: always()
     runs-on: ubuntu-latest
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
name: SFT End-to-End Tests

on:
  workflow_call:

jobs:
  sft-maniskill-openpi-test:
    runs-on: embodied
    steps:
      - name: Checkout code
        uses: actions/checkout@v5

      - name: Create sft environment
        run: |
          unset UV_DEFAULT_INDEX
          export UV_PATH=/workspace/dataset/.uv
          export UV_LINK_MODE=symlink
          export UV_CACHE_DIR=/workspace/dataset/.uv_cache
          export UV_PYTHON_INSTALL_DIR=/workspace/dataset/.uv_python
          export LIBERO_PATH=/workspace/dataset/LIBERO
          bash requirements/install.sh embodied --model openpi --env maniskill_libero

      - name: SFT ManiSkill OpenPI test
        timeout-minutes: 20
        run: |
          export REPO_PATH=$(pwd)
          source .venv/bin/activate
          bash tests/e2e_tests/sft/run.sh maniskill_sft_openpi

      - name: Clean up
        run: |
          rm -rf .venv
          uv cache prune

README.md

Lines changed: 2 additions & 2 deletions

@@ -114,8 +114,8 @@ RLinf is a flexible and scalable open-source infrastructure designed for post-training
   </ul>
   <li><b>SFT</b></li>
   <ul>
-    <li>Full-parameter SFT</li>
-    <li>LoRA SFT</li>
+    <li><a href="https://rlinf.readthedocs.io/en/latest/rst_source/examples/fine_tine.html">Full-parameter SFT</a> ✅</li>
+    <li><a href="https://rlinf.readthedocs.io/en/latest/rst_source/examples/fine_tine.html">LoRA SFT</a> ✅</li>
   </ul>
   </ul>
   </td>

README.zh-CN.md

Lines changed: 2 additions & 2 deletions

@@ -113,8 +113,8 @@ RLinf is a flexible and scalable open-source framework purpose-built for post-training with reinforcement learning
   </ul>
   <li><b>SFT</b></li>
   <ul>
-    <li>Full-parameter fine-tuning</li>
-    <li>LoRA fine-tuning</li>
+    <li><a href="https://rlinf.readthedocs.io/zh-cn/latest/rst_source/examples/fine_tine.html">Full-parameter fine-tuning</a> ✅</li>
+    <li><a href="https://rlinf.readthedocs.io/zh-cn/latest/rst_source/examples/fine_tine.html">LoRA fine-tuning</a> ✅</li>
   </ul>
   </ul>
   </td>

docs/source-en/rst_source/examples/index.rst

Lines changed: 1 addition & 0 deletions

@@ -257,3 +257,4 @@ Thanks to this decoupled design, workers can be flexibly and dynamically scheduled
   gr00t
   reasoning
   coding_online_rl
+  sft
Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@

Supervised Fine-Tuning
======================

.. |huggingface| image:: /_static/svg/hf-logo.svg
   :width: 16px
   :height: 16px
   :class: inline-icon

This page explains how to run **full-parameter supervised fine-tuning (SFT)** and **LoRA fine-tuning** with the RLinf framework. SFT is typically the first stage before reinforcement learning: the model imitates high-quality examples so that RL can continue optimization from a strong prior.

Contents
--------

- How to configure full-parameter SFT and LoRA SFT in RLinf
- How to launch training on a single machine or a multi-node cluster
- How to monitor and evaluate results

Supported datasets
------------------

RLinf currently supports datasets in the LeRobot format, selected via **config_type**.

Supported formats include:

- pi0_maniskill
- pi0_libero
- pi05_libero
- pi05_maniskill
- pi05_metaworld
- pi05_calvin
You can also train with a custom dataset format. Refer to the steps below:

1. In ``examples/sft/config/custom_sft_openpi.yaml``, set the data format.

   .. code:: yaml

      model:
        openpi:
          config_name: "pi0_custom"

2. In ``rlinf/models/embodiment/openpi/__init__.py``, set the data format to ``pi0_custom``.

   .. code:: python

      TrainConfig(
          name="pi0_custom",
          model=pi0_config.Pi0Config(),
          data=CustomDataConfig(
              repo_id="physical-intelligence/custom_dataset",
              base_config=DataConfig(
                  prompt_from_task=True
              ),  # we need the language instruction
              assets=AssetsConfig(assets_dir="checkpoints/torch/pi0_base/assets"),
              extra_delta_transform=True,  # True for delta actions, False for absolute actions
              action_train_with_rotation_6d=False,  # users can add extra options in a custom dataset
          ),
          pytorch_weight_path="checkpoints/torch/pi0_base",
      ),

3. In ``rlinf/models/embodiment/openpi/dataconfig/custom_dataconfig.py``, define the custom dataset config.

   .. code:: python

      class CustomDataConfig(DataConfig):
          def __init__(self, *args, **kwargs):
              super().__init__(*args, **kwargs)
              self.repo_id = "physical-intelligence/custom_dataset"
              self.base_config = DataConfig(
                  prompt_from_task=True
              )
              self.assets = AssetsConfig(assets_dir="checkpoints/torch/pi0_base/assets")
              self.extra_delta_transform = True
              self.action_train_with_rotation_6d = False
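The ``extra_delta_transform`` flag above selects between delta actions and absolute actions. A minimal illustration of the distinction, using the generic robotics convention (illustrative only, not RLinf's actual transform):

```python
# Delta vs. absolute actions (generic convention, not RLinf's transform):
# an absolute action stores the target state directly; a delta action
# stores the change relative to the previously commanded state.
def to_delta(actions: list[float], state: float) -> list[float]:
    """Convert a chunk of absolute targets into deltas from the current state."""
    deltas = []
    prev = state
    for a in actions:
        deltas.append(a - prev)  # change relative to the previous target
        prev = a
    return deltas

# to_delta([1.0, 1.5, 1.5], state=0.5) -> [0.5, 0.5, 0.0]
```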
Training configuration
----------------------

A full example lives in ``examples/sft/config/libero_sft_openpi.yaml``. Key fields:

.. code:: yaml

   cluster:
     num_nodes: 1           # number of nodes
     component_placement:   # component → GPU mapping
       actor: 0-3
To enable LoRA fine-tuning, set ``actor.model.is_lora`` to ``True`` and configure ``actor.model.lora_rank``.

.. code:: yaml

   actor:
     model:
       is_lora: True
       lora_rank: 32
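For intuition on what ``lora_rank`` controls: a rank-``r`` LoRA adapter on a ``d_out × d_in`` linear layer trains ``r * (d_in + d_out)`` parameters instead of the full ``d_out * d_in``. A back-of-the-envelope helper (generic LoRA arithmetic, not RLinf code):

```python
# Generic LoRA parameter count (not RLinf-specific): a rank-r adapter
# factorizes the weight update as B @ A, with A of shape (r, d_in) and
# B of shape (d_out, r), adding r * (d_in + d_out) trainable parameters.
def lora_params(d_in: int, d_out: int, rank: int) -> int:
    return rank * (d_in + d_out)

# e.g. a hypothetical 4096x4096 projection with lora_rank 32:
# lora_params(4096, 4096, 32) -> 262144 trainable parameters,
# versus 4096 * 4096 = 16777216 for full fine-tuning of that layer.
```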
Launch scripts
--------------

First start the Ray cluster, then run the helper script:

.. code:: bash

   cd /path_to_RLinf/ray_utils
   bash start_ray.sh  # start head + workers

   # return to the repo root
   python examples/sft/train_embodied_sft.py --config libero_sft_openpi.yaml

The same script works for generic text SFT; just swap the config file.

docs/source-zh/rst_source/examples/index.rst

Lines changed: 1 addition & 0 deletions

@@ -252,3 +252,4 @@ RLinf's overall design is simple and modular, with Worker as the abstraction encapsulating RL training
   gr00t
   reasoning
   coding_online_rl
+  sft
Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@

Supervised Fine-Tuning
======================

.. |huggingface| image:: /_static/svg/hf-logo.svg
   :width: 16px
   :height: 16px
   :class: inline-icon

This page explains how to run **full-parameter SFT** and **LoRA fine-tuning** in the RLinf framework. SFT usually serves as the first stage before reinforcement learning: the model first imitates high-quality demonstrations so that the subsequent RL stage can continue optimizing from a good prior.

Contents
--------

- How to configure full-parameter SFT and LoRA fine-tuning in RLinf
- How to launch training on a single machine or a multi-node cluster
- How to monitor and evaluate results

Supported datasets
------------------

RLinf currently supports datasets in the LeRobot format; the dataset type is selected via **config_type**.

Currently supported data formats:

- pi0_maniskill
- pi0_libero
- pi05_libero
- pi05_maniskill
- pi05_metaworld
- pi05_calvin

You can also train on a specific dataset with a custom data format; refer to the files below.

1. In ``examples/sft/config/custom_sft_openpi.yaml``, specify the data format.

   .. code:: yaml

      model:
        openpi:
          config_name: "pi0_custom"

2. In ``rlinf/models/embodiment/openpi/__init__.py``, set the data format to ``pi0_custom``.

   .. code:: python

      TrainConfig(
          name="pi0_custom",
          model=pi0_config.Pi0Config(),
          data=CustomDataConfig(
              repo_id="physical-intelligence/custom_dataset",
              base_config=DataConfig(
                  prompt_from_task=True
              ),  # we need the language instruction
              assets=AssetsConfig(assets_dir="checkpoints/torch/pi0_base/assets"),
              extra_delta_transform=True,  # True for delta actions, False for absolute actions
              action_train_with_rotation_6d=False,  # users can add extra options in a custom dataset
          ),
          pytorch_weight_path="checkpoints/torch/pi0_base",
      ),

3. In ``rlinf/models/embodiment/openpi/dataconfig/custom_dataconfig.py``, define the custom dataset config.

   .. code:: python

      class CustomDataConfig(DataConfig):
          def __init__(self, *args, **kwargs):
              super().__init__(*args, **kwargs)
              self.repo_id = "physical-intelligence/custom_dataset"
              self.base_config = DataConfig(
                  prompt_from_task=True
              )
              self.assets = AssetsConfig(assets_dir="checkpoints/torch/pi0_base/assets")
              self.extra_delta_transform = True
              self.action_train_with_rotation_6d = False

Training configuration
----------------------

A full example config lives in ``examples/sft/config/libero_sft_openpi.yaml``; the key fields are:

.. code:: yaml

   cluster:
     num_nodes: 1           # number of nodes
     component_placement:   # component → GPU mapping
       actor: 0-3

To enable LoRA fine-tuning, set ``actor.model.is_lora`` to ``True`` and configure ``actor.model.lora_rank``.

.. code:: yaml

   actor:
     model:
       is_lora: True
       lora_rank: 32

Launch scripts
--------------

First start the Ray cluster, then run the helper script:

.. code:: bash

   cd /path_to_RLinf/ray_utils
   bash start_ray.sh  # start head + workers

   # return to the repo root
   python examples/sft/train_embodied_sft.py --config libero_sft_openpi.yaml

The same script also works for generic text SFT; just swap the config file.
Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
defaults:
  - model/pi0@actor.model
  - training_backend/fsdp@actor.fsdp_config
  - override hydra/job_logging: stdout

hydra:
  run:
    dir: .
  output_subdir: null
  searchpath:
    - file://${oc.env:EMBODIED_PATH}/config/

cluster:
  num_nodes: 1
  component_placement:
    actor,env,rollout: 0-0

runner:
  task_type: sft
  logger:
    log_path: "../results"
    project_name: rlinf
    experiment_name: "test_openpi"
    logger_backends: ["tensorboard"]  # wandb, swanlab

  max_epochs: 1000
  max_steps: -1
  val_check_interval: -1
  save_interval: 10

data:
  data_path: "/path/to/custom-data"

algorithm:
  adv_type: gae

actor:
  group_name: "ActorGroup"
  training_backend: "fsdp"
  micro_batch_size: 1
  global_batch_size: 16
  seed: 0

  # Override the default values in model/pi0
  model:
    precision: null
    model_path: "/path/to/pi0-model"
    num_action_chunks: 4  # interface for the env
    add_value_head: True
    openpi:
      config_name: "pi0_custom"
    detach_critic_input: True

  optim:
    lr: 7.91e-6
    value_lr: 1.55e-4
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-05
    clip_grad: 1.0

  # Override the default values in training_backend/fsdp
  fsdp_config:
    strategy: "fsdp"
    sharding_strategy: "no_shard"
    use_orig_params: True
    gradient_checkpointing: False  # gradient checkpointing is not supported for openpi; do not change this value
    mixed_precision:
      param_dtype: ${actor.model.precision}
      reduce_dtype: ${actor.model.precision}
      buffer_dtype: ${actor.model.precision}

reward:
  use_reward_model: False

critic:
  use_critic_model: False
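The batch-size fields in the config above follow the usual data-parallel relationship: with ``micro_batch_size: 1`` and ``global_batch_size: 16``, each optimizer step accumulates ``16 / (1 × world_size)`` micro-batches per rank. A sketch of that arithmetic (generic, not RLinf internals):

```python
# Generic data-parallel batch arithmetic (not RLinf internals): the number
# of gradient-accumulation micro-steps per optimizer step given the global
# batch size, per-GPU micro batch size, and number of data-parallel ranks.
def grad_accum_steps(global_bs: int, micro_bs: int, world_size: int) -> int:
    per_step = micro_bs * world_size  # samples processed per micro-step
    if global_bs % per_step:
        raise ValueError(
            "global_batch_size must be divisible by micro_batch_size * world_size"
        )
    return global_bs // per_step

# With the config above on a single GPU (placement 0-0):
# grad_accum_steps(16, 1, 1) -> 16 micro-batches accumulated per step.
```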
