Merged
1 change: 1 addition & 0 deletions .github/workflows/unittest.yaml
@@ -18,6 +18,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
path: trinity-${{ github.run_id }}
ref: refs/pull/${{ github.event.issue.number }}/head

17 changes: 9 additions & 8 deletions README.md
@@ -152,6 +152,12 @@ It is designed to support diverse application scenarios and serve as a unified p
### Step 1: installation


Requirements:
- Python version >= 3.10, <= 3.12
- CUDA version >= 12.4, <= 12.8
- At least 2 GPUs


Installation from source **(recommended)**:

```shell
@@ -181,13 +187,15 @@ pip install -e .[flash_attn]
# for zsh
pip install -e .\[flash_attn\]
# Try the following command if you encounter errors during flash-attn installation
# pip install flash-attn -v --no-build-isolation
# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
```

Installation using pip:

```shell
pip install trinity-rft==0.2.0
# install flash-attn separately
pip install flash-attn==2.8.0.post2
```
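The version constraints listed above can be sanity-checked before installing. This is an illustrative sketch, not part of the repository; the torch-based CUDA/GPU check is best-effort and only runs if torch is already installed:

```python
import sys

def supported_python(v=None):
    """Return True if the interpreter satisfies Python >= 3.10, <= 3.12."""
    v = v or sys.version_info
    return (3, 10) <= (v[0], v[1]) <= (3, 12)

if __name__ == "__main__":
    print("Python OK" if supported_python() else "Unsupported Python version")
    try:
        # torch may not be installed yet; GPU/CUDA checks are optional here
        import torch
        print("CUDA:", torch.version.cuda, "GPUs:", torch.cuda.device_count())
    except ImportError:
        print("Install torch to verify CUDA >= 12.4, <= 12.8 and GPU count >= 2")
```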

Installation from docker:
@@ -206,13 +214,6 @@ docker build -f scripts/docker/Dockerfile -t trinity-rft:latest .
docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path_of_data_and_checkpoints>:/data trinity-rft:latest
```


**Requirements:**
Python version >= 3.10,
CUDA version >= 12.4,
and at least 2 GPUs.


### Step 2: prepare dataset and model


15 changes: 8 additions & 7 deletions README_zh.md
@@ -151,6 +151,11 @@ Trinity-RFT is a general-purpose, flexible, and easy-to-use large language model reinforcement fine-tuning

### Step 1: installation

Requirements:
- Python >= 3.10, <= 3.12
- CUDA >= 12.4, <= 12.8
- At least 2 GPUs


Installation from source **(recommended)**:

@@ -181,13 +186,15 @@ pip install -e .[flash_attn]
# for zsh
pip install -e .\[flash_attn\]
# Try the following command if you encounter errors during flash-attn installation
# pip install flash-attn -v --no-build-isolation
# pip install flash-attn==2.8.0.post2 -v --no-build-isolation
```

Installation using pip:

```shell
pip install trinity-rft==0.2.0
# install flash-attn separately
pip install flash-attn==2.8.0.post2
```

Installation from docker:
Expand All @@ -207,12 +214,6 @@ docker run -it --gpus all --shm-size="64g" --rm -v $PWD:/workspace -v <root_path
```


**Requirements:**
Python version >= 3.10,
CUDA version >= 12.4,
and at least 2 GPUs.


### Step 2: prepare dataset and model


3 changes: 2 additions & 1 deletion docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -142,7 +142,8 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to
cumsum = torch.cumsum(attention_mask, dim=-1)
position_ids = torch.clip(cumsum - 1, 0, None).long()
batch_dict = {
"uid": np.array(experiences.group_ids),
"uid": np.array([eid.tid for eid in experiences.eids]),
"unique_ids": np.array([eid.uid for eid in experiences.eids]),
"position_ids": position_ids,
"input_ids": experiences.tokens.long(),
"responses": experiences.tokens[:, experiences.prompt_length :].long(),
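The `position_ids` computation in the hunk above (a cumulative sum of the attention mask, shifted and clipped at zero) can be sketched without torch; this pure-Python illustration is not part of the codebase:

```python
from itertools import accumulate

def position_ids_from_mask(attention_mask):
    """Mirror torch.clip(torch.cumsum(mask, dim=-1) - 1, 0, None) on 0/1 rows.

    Left-padding tokens (mask == 0) stay at position 0; real tokens are
    numbered 0, 1, 2, ... starting from the first unmasked token.
    """
    return [[max(c - 1, 0) for c in accumulate(row)] for row in attention_mask]

print(position_ids_from_mask([[0, 0, 1, 1, 1]]))  # [[0, 0, 0, 1, 2]]
```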
2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/faq.md
@@ -65,7 +65,7 @@ File ".../flash_attn/flash_attn_interface.py", line 15, in <module>
ImportError: ...
```

**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn` or `pip install flash-attn -v --no-build-isolation`.
**A:** The `flash-attn` module is not properly installed. Try to fix it by running `pip install flash-attn==2.8.0.post2` or `pip install flash-attn==2.8.0.post2 -v --no-build-isolation`.
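A guarded import (an illustrative sketch, not part of the repository) is a quick way to tell whether `flash-attn` is usable in the current environment before launching training:

```python
import importlib

def flash_attn_available():
    """Return True if flash_attn imports cleanly, False on ImportError."""
    try:
        importlib.import_module("flash_attn")
        return True
    except ImportError:
        return False

print("flash-attn available:", flash_attn_available())
```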

---

3 changes: 2 additions & 1 deletion docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -90,6 +90,7 @@ algorithm:
kl_penalty_fn: "none"
kl_loss_fn: "k2"
entropy_loss_fn: "default"
add_strategy: null
```

- `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`, `sft`, `mix`.
Expand All @@ -99,7 +100,7 @@ algorithm:
- `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
- `kl_loss_fn`: The KL loss function used for computing KL loss.
- `entropy_loss_fn`: The entropy loss function used for computing entropy loss.

- `add_strategy`: Strategy for adding new experiences to the experience buffer. If set, the explorer collects experiences from workflow runners and pre-processes them before adding them to the buffer.

---

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ requires-python = ">=3.10"
dependencies = [
"verl==0.4.1",
"ray[default]>=2.45.0",
"vllm==0.9.2",
"vllm>=0.9.1",
"tensordict==0.6.2",
"wandb",
"omegaconf",
3 changes: 1 addition & 2 deletions tests/buffer/queue_test.py
@@ -30,7 +30,7 @@ class TestQueueBuffer(RayUnittestBaseAysnc):
)
async def test_queue_buffer(self, name, use_priority_queue):
meta = StorageConfig(
name="test_buffer",
name=name,
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
max_read_timeout=3,
@@ -60,7 +60,6 @@ async def test_queue_buffer(self, name, use_priority_queue):
exps = [
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
2 changes: 0 additions & 2 deletions tests/buffer/sql_test.py
@@ -38,7 +38,6 @@ async def test_create_sql_buffer(self) -> None:
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
)
for i in range(1, put_batch_size + 1)
]
@@ -54,7 +53,6 @@
[
Experience(
tokens=torch.tensor([float(j) for j in range(i + 1)]),
prompt_length=i,
reward=float(i),
logprobs=torch.tensor([0.1]),
action_mask=torch.tensor([j % 2 for j in range(i + 1)]),