Skip to content

Commit 65a27b7

Browse files
committed
Merge branch 'main' of github.com:modelscope/Trinity-RFT into dev/add_sync_v2
2 parents a689262 + 8a1d316 commit 65a27b7

34 files changed

+1150
-475
lines changed

.github/workflows/unittest.yaml

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@ permissions:
1212
jobs:
1313
unittest:
1414
# only run on pull request
15-
if: ${{ github.event.issue.pull_request && startsWith(github.event.comment.body, '/run-unittest') && github.event.comment.author_association == 'COLLABORATOR' }}
15+
if: ${{ github.event.issue.pull_request && (startsWith(github.event.comment.body, '/unittest')) && github.event.comment.author_association == 'COLLABORATOR' }}
1616
runs-on: self-hosted
1717

1818
steps:
1919
- uses: actions/checkout@v4
2020
with:
21+
fetch-depth: 0
2122
path: trinity-${{ github.run_id }}
2223
ref: refs/pull/${{ github.event.issue.number }}/head
2324

@@ -33,18 +34,70 @@ jobs:
3334
docker compose exec trinity-node-1 ray status
3435
docker compose exec trinity-node-2 ray status
3536
37+
- name: Decide test type
38+
id: test_type
39+
working-directory: trinity-${{ github.run_id }}
40+
run: |
41+
COMMENT="${{ github.event.comment.body }}"
42+
if [[ "$COMMENT" == "/unittest-all"* ]]; then
43+
echo "type=all" >> $GITHUB_OUTPUT
44+
elif [[ "$COMMENT" == "/unittest-diff"* ]]; then
45+
echo "type=diff" >> $GITHUB_OUTPUT
46+
elif [[ "$COMMENT" =~ ^/unittest-module-(.+)$ ]]; then
47+
MODULE=$(echo "$COMMENT" | sed -n 's/\/unittest-module-\(.*\)/\1/p')
48+
echo "type=module" >> $GITHUB_OUTPUT
49+
echo "module=$MODULE" >> $GITHUB_OUTPUT
50+
else
51+
echo "type=all" >> $GITHUB_OUTPUT
52+
fi
53+
54+
- name: Get changed modules (for diff)
55+
if: steps.test_type.outputs.type == 'diff'
56+
id: diff
57+
working-directory: trinity-${{ github.run_id }}
58+
run: |
59+
git fetch origin main
60+
git diff --name-only origin/main...HEAD > changed_files.txt
61+
awk -F/ '/^(trinity)\// {print $2}' changed_files.txt | sort | uniq > changed_modules.txt
62+
awk '{print "tests/"$1}' changed_modules.txt > test_dirs.txt
63+
3664
- name: Run unittest
3765
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
3866
run: |
39-
docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
67+
TYPE="${{ steps.test_type.outputs.type }}"
68+
if [ "$TYPE" = "all" ]; then
69+
docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
70+
echo "tests_run=true" >> $GITHUB_ENV
71+
elif [ "$TYPE" = "diff" ]; then
72+
ROOT_DIR=trinity-${{ github.run_id }}
73+
if [ -s "$ROOT_DIR/test_dirs.txt" ]; then
74+
TEST_DIRS=$(cat "$ROOT_DIR/test_dirs.txt" | xargs)
75+
docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json
76+
echo "tests_run=true" >> $GITHUB_ENV
77+
else
78+
echo "No changed modules detected, skipping tests."
79+
echo "tests_run=false" >> $GITHUB_ENV
80+
fi
81+
elif [ "$TYPE" = "module" ]; then
82+
MODULE="${{ steps.test_type.outputs.module }}"
83+
if [ -n "$MODULE" ]; then
84+
docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json
85+
echo "tests_run=true" >> $GITHUB_ENV
86+
else
87+
echo "No module specified, skipping tests."
88+
echo "tests_run=false" >> $GITHUB_ENV
89+
fi
90+
fi
4091
4192
- name: Upload test results
93+
if: env.tests_run == 'true'
4294
uses: actions/upload-artifact@v4
4395
with:
4496
name: pytest-results
4597
path: trinity-${{ github.run_id }}/report.json
4698

4799
- name: Publish Test Report
100+
if: env.tests_run == 'true'
48101
uses: ctrf-io/github-test-reporter@v1
49102
with:
50103
report-path: trinity-${{ github.run_id }}/report.json
@@ -53,7 +106,6 @@ jobs:
53106
issue: ${{ github.event.issue.number }}
54107
env:
55108
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
56-
if: always()
57109

58110
- name: Remove docker compose
59111
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker

README.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ It is designed to support diverse application scenarios and serve as a unified p
152152
### Step 1: installation
153153

154154

155+
Requirements:
156+
- Python version >= 3.10, <= 3.12
157+
- CUDA version >= 12.4, <= 12.8
158+
- At least 2 GPUs
159+
160+
155161
Installation from source **(recommended)**:
156162

157163
```shell
@@ -181,13 +187,15 @@ pip install -e .[flash_attn]
181187
# for zsh
182188
pip install -e .\[flash_attn\]
183189
# Try the following command if you encounter errors during flash-attn installation
184-
# pip install flash-attn -v --no-build-isolation
190+
# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
185191
```
186192

187193
Installation using pip:
188194

189195
```shell
190196
pip install trinity-rft==0.2.0
197+
# install flash-attn separately
198+
pip install flash-attn==2.8.0.post2
191199
```
192200

193201
Installation from docker:
@@ -206,13 +214,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
206214
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
207215
```
208216

209-
210-
**Requirements:**
211-
Python version >= 3.10,
212-
CUDA version >= 12.4,
213-
and at least 2 GPUs.
214-
215-
216217
### Step 2: prepare dataset and model
217218

218219

README_zh.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,11 @@ Trinity-RFT是一个通用、灵活且易于使用的大语言模型强化微调
151151

152152
### 第一步:安装
153153

154+
环境要求:
155+
- Python >= 3.10, <= 3.12
156+
- CUDA >= 12.4, <= 12.8
157+
- 至少 2 块 GPU
158+
154159

155160
源码安装 **(推荐)**
156161

@@ -181,13 +186,15 @@ pip install -e .[flash_attn]
181186
# 适用于 zsh
182187
pip install -e .\[flash_attn\]
183188
# 如果安装 flash-attn 时遇到错误,可以尝试以下命令
184-
# pip install flash-attn -v --no-build-isolation
189+
# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
185190
```
186191

187192
使用 pip 安装:
188193

189194
```shell
190195
pip install trinity-rft==0.2.0
196+
# flash-attn 需要单独安装
197+
pip install flash-attn==2.8.0.post2
191198
```
192199

193200
使用 Docker 安装:
@@ -207,12 +214,6 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path
207214
```
208215

209216

210-
**环境要求:**
211-
Python 版本 >= 3.10,
212-
CUDA 版本 >= 12.4,
213-
以及至少 2 块 GPU。
214-
215-
216217
### 第二步:准备数据集和模型
217218

218219

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to
142142
cumsum = torch.cumsum(attention_mask, dim=-1)
143143
position_ids = torch.clip(cumsum - 1, 0, None).long()
144144
batch_dict = {
145-
"uid": np.array(experiences.group_ids),
145+
"uid": np.array([eid.tid for eid in experiences.eids]),
146+
"unique_ids": np.array([eid.uid for eid in experiences.eids]),
146147
"position_ids": position_ids,
147148
"input_ids": experiences.tokens.long(),
148149
"responses": experiences.tokens[:, experiences.prompt_length :].long(),

docs/sphinx_doc/source/tutorial/faq.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ File ".../flash_attn/flash_attn_interface.py", line 15, in ‹module>
6565
ImportError: ...
6666
```
6767

68-
**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn` or `pip install flash-attn -v --no-build-isolation`.
68+
**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn==2.8.0.post2` or `pip install flash-attn==2.8.0.post2 -v --no-build-isolation`.
6969

7070
---
7171

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ algorithm:
9090
kl_penalty_fn: "none"
9191
kl_loss_fn: "k2"
9292
entropy_loss_fn: "default"
93+
add_strategy: null
9394
```
9495
9596
- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`, `sft`, `mix`.
@@ -99,7 +100,7 @@ algorithm:
99100
- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
100101
- `kl_loss_fn`: The KL loss function used for computing KL loss.
101102
- `entropy_loss_fn`: The entropy loss function used for computing entropy loss.
102-
103+
- `add_strategy`: Strategy for adding new experiences to the experience buffer. If set, explorer will collect experiences from workflow runners and pre-process them before adding to the buffer.
103104

104105
---
105106

tests/buffer/queue_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TestQueueBuffer(RayUnittestBaseAysnc):
3030
)
3131
async def test_queue_buffer(self, name, use_priority_queue):
3232
meta = StorageConfig(
33-
name="test_buffer",
33+
name=name,
3434
algorithm_type="ppo",
3535
storage_type=StorageType.QUEUE,
3636
max_read_timeout=3,
@@ -60,7 +60,6 @@ async def test_queue_buffer(self, name, use_priority_queue):
6060
exps = [
6161
Experience(
6262
tokens=torch.tensor([float(j) for j in range(i + 1)]),
63-
prompt_length=i,
6463
reward=float(i),
6564
logprobs=torch.tensor([0.1]),
6665
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),

tests/buffer/sql_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ async def test_create_sql_buffer(self) -> None:
3838
prompt_length=i,
3939
reward=float(i),
4040
logprobs=torch.tensor([0.1]),
41-
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
4241
)
4342
for i in range(1, put_batch_size + 1)
4443
]
@@ -54,7 +53,6 @@ async def test_create_sql_buffer(self) -> None:
5453
[
5554
Experience(
5655
tokens=torch.tensor([float(j) for j in range(i + 1)]),
57-
prompt_length=i,
5856
reward=float(i),
5957
logprobs=torch.tensor([0.1]),
6058
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),

0 commit comments

Comments
 (0)