
Commit b2282ae

[Example] Frozen_Lake (#375)
1 parent 980cffc commit b2282ae

File tree: 18 files changed (+929, -14 lines)

docs/sphinx_doc/source/tutorial/example_reasoning_basic.md

Lines changed: 2 additions & 1 deletion
@@ -29,13 +29,14 @@ Download the GSM8K dataset to the local directory `$DATASET_PATH/gsm8k`:

```bash
# Using Modelscope
-modelscope download --dataset modelscope/gsm8k --local_dir $DATASET_PATH/gsm8k
+modelscope download --dataset AI-ModelScope/gsm8k --local_dir $DATASET_PATH/gsm8k

# Using Huggingface
huggingface-cli download openai/gsm8k --repo-type dataset --local-dir $DATASET_PATH/gsm8k
```

More details on dataset downloading can be found in the [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space) documentation.
+The dataset downloaded from ModelScope may lack the `dtype` field, which causes an error when the dataset is loaded. To fix this, delete the `dataset_infos.json` file and run the experiment again.

## Step 2: Set up Configuration and Run Experiment
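For concreteness, the workaround for the missing `dtype` field noted above amounts to removing the stale metadata file from the downloaded dataset directory (a minimal sketch, assuming the file sits at the top level of the download location used in this guide):

```bash
# Remove the stale metadata so the dataset is re-inspected on the next load
rm $DATASET_PATH/gsm8k/dataset_infos.json
```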

docs/sphinx_doc/source/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
@@ -163,6 +163,7 @@ model:
  max_prompt_tokens: 4096
  max_response_tokens: 16384
  min_response_tokens: 1
+  enable_prompt_truncation: true
```

- `model_path`: Path to the model being trained.

@@ -173,6 +174,7 @@ model:
- `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`.
- `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`.
+- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated, and there is a risk that the prompt length plus the response length exceeds `max_model_len`.

```{tip}
If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
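For illustration, a minimal `model` block (the numbers here are made up for this sketch, not recommended values) showing how these limits interact when truncation is disabled:

```yaml
model:
  model_path: /path/to/model
  max_model_len: 8192              # hard limit on prompt + response tokens
  max_prompt_tokens: 4096          # not enforced by truncation when the flag below is false
  max_response_tokens: 4096
  enable_prompt_truncation: false  # an over-long prompt may push prompt + response past max_model_len
```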

docs/sphinx_doc/source_zh/tutorial/example_reasoning_basic.md

Lines changed: 2 additions & 1 deletion
@@ -30,13 +30,14 @@ huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen

```bash
# Using Modelscope
-modelscope download --dataset modelscope/gsm8k --local_dir $DATASET_PATH/gsm8k
+modelscope download --dataset AI-ModelScope/gsm8k --local_dir $DATASET_PATH/gsm8k

# Using Huggingface
huggingface-cli download openai/gsm8k --repo-type dataset --local-dir $DATASET_PATH/gsm8k
```

For more details on dataset downloading, refer to [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space).
+The dataset downloaded from ModelScope may lack the `dtype` field, causing an error when the dataset is loaded. To fix this, delete the `dataset_infos.json` file and rerun the experiment.

## Step 2: Set up Configuration and Run the Experiment

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md

Lines changed: 2 additions & 0 deletions
@@ -163,6 +163,7 @@ model:
  max_prompt_tokens: 4096
  max_response_tokens: 16384
  min_response_tokens: 1
+  enable_prompt_truncation: true
```

- `model_path`: Path to the model being trained.

@@ -173,6 +174,7 @@ model:
- `max_prompt_tokens`: Maximum number of tokens allowed in the input prompt. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `max_response_tokens`: Maximum number of tokens allowed in model-generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
- `min_response_tokens`: Minimum number of tokens allowed in model-generated responses. Only effective for the `chat` and `generate` methods in `InferenceModel`.
+- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt is truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt is not truncated, and the combined prompt and response length may exceed `max_model_len`.

```{tip}
If you are using the openai API provided by Explorer, only `max_model_len` takes effect; the values of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` are ignored. When `max_tokens` is not specified separately, each API call generates up to `max_model_len - prompt_length` tokens, so please make sure the prompt length is less than `max_model_len` when using the API.
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@

# Frozen Lake

This example shows the usage of GRPO on the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task. Note that this task has only been tested with Qwen2.5 Instruct models.

## Data and Environment Preparation

After setting up the basic environment following the [installation guidance](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html), install the additional dependency by running the following command:

```bash
pip install gymnasium[toy_text]
```

Then, prepare the dataset by running the following commands:

```bash
cd examples/grpo_frozen_lake
python get_frozen_lake_data.py
```

This command saves the dataset to a local directory and prints its path. Afterwards, make sure to set the environment variable `TRINITY_TASKSET_PATH` to that path:

```bash
export TRINITY_TASKSET_PATH=/path/to/frozenlake
```

## Workflow Configuration and Training

We use a concatenated multi-turn workflow, `FrozenLakeWorkflow`, to solve the Frozen Lake task. For each rollout, the multi-turn interaction between the agent and the environment feedback is stored in a single `Experience` object, as sketched below.
The specific configuration is located in [`frozen_lake.yaml`](frozen_lake.yaml).
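As an illustration of the concatenated format (not the actual `FrozenLakeWorkflow` implementation; the class and field names below are assumptions made for this sketch), a multi-turn rollout can be flattened into one token sequence with a mask marking which tokens the agent produced:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ConcatenatedExperience:
    """Illustrative single-sequence view of a multi-turn rollout (not the Trinity-RFT API)."""

    tokens: List[int] = field(default_factory=list)       # prompt + all turns, concatenated
    action_mask: List[int] = field(default_factory=list)  # 1 = agent-generated token, 0 = context
    reward: float = 0.0                                    # final environment reward for the episode

    def append_turn(self, env_tokens: List[int], agent_tokens: List[int]) -> None:
        # Environment feedback (the rendered board, step results) is context only.
        self.tokens += env_tokens
        self.action_mask += [0] * len(env_tokens)
        # Agent response tokens are the actions that the policy loss applies to.
        self.tokens += agent_tokens
        self.action_mask += [1] * len(agent_tokens)
```

A plausible reading of this design is that the single episode reward attaches to one sequence, while the mask keeps environment-produced tokens out of the policy loss.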
To run this example, use the following command:

```bash
trinity run --config examples/grpo_frozen_lake/frozen_lake.yaml
```

## Results

We show the results with a Qwen2.5-3B-Instruct model below. The figures show that both the reward and the test score increase over training steps.

![reward](frozen_lake_reward.png)

![test_score](frozen_lake_test_score.png)
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@

project: "FrozenLake"
name: "trinity-frozen-lake"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: grpo
  repeat_times: 8
  optimizer:
    lr: 1e-6
  policy_loss_fn_args:
    loss_agg_mode: "seq-mean-token-sum"
    clip_range_low: 0.2
    clip_range_high: 0.28
  kl_loss_fn_args:
    kl_coef: 0.0
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}
  enable_prompt_truncation: false
  max_response_tokens: 10240
  max_model_len: 14436
  temperature: 0.7
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 1
  batch_size: 64
  explorer_input:
    taskset:
      name: frozenlake
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH}
      split: train
      workflow_args:
        env_max_steps: 8
        agent_max_steps: 10
        is_slippery: false
    eval_tasksets:
      - name: frozenlake
        storage_type: file
        path: ${oc.env:TRINITY_TASKSET_PATH}
        split: test
        workflow_args:
          env_max_steps: 8
          agent_max_steps: 10
          is_slippery: false
        rollout_args:
          n: 4
          top_p: 0.8
          top_k: 20
    default_workflow_type: 'frozen_lake_workflow'
explorer:
  eval_on_startup: true
  eval_interval: 10
  runner_per_model: 8
  rollout_model:
    engine_num: 6
    tensor_parallel_size: 1
    enable_chunked_prefill: true
    enforce_eager: false
    dtype: bfloat16
    seed: 42
    gpu_memory_utilization: 0.85
trainer:
  trainer_type: 'verl'
  save_interval: 1000
  use_dynamic_bsz: true
  max_token_len_per_gpu: 16384
  ulysses_sequence_parallel_size: 1
  trainer_config:
    actor_rollout_ref:
      hybrid_engine: true
      model:
        use_remove_padding: true
        enable_gradient_checkpointing: true
      actor:
        fsdp_config:
          param_offload: true
          optimizer_offload: true
      ref:
        fsdp_config:
          param_offload: true
synchronizer:
  sync_method: nccl
  sync_interval: 1
  sync_timeout: 1200
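A note on the token budget implied by these numbers (simple arithmetic, not something stated in the config itself): with `max_model_len: 14436` and `max_response_tokens: 10240`, a single generation call leaves at most 14436 - 10240 = 4196 tokens for the accumulated multi-turn prompt. Because `enable_prompt_truncation` is `false`, a rollout whose context grows beyond that budget risks exceeding `max_model_len`, which is the trade-off documented for `enable_prompt_truncation` in the `trinity_configs.md` change above.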
Two binary image files (522 KB and 459 KB) are added in this commit; they are the result figures referenced in the README (`frozen_lake_reward.png`, `frozen_lake_test_score.png`) and are not rendered in the diff.
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@

"""
Modified from https://github.com/rllm-org/rllm/blob/main/examples/frozenlake/prepare_frozenlake_data.py
"""
import os

import numpy as np
import pandas as pd

from trinity.common.constants import TASKSET_PATH_ENV_VAR

path_from_env = os.environ.get(TASKSET_PATH_ENV_VAR)
if path_from_env is not None:
    DATA_ROOT_DIR = os.path.dirname(path_from_env)
else:
    DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), "data")


def save_dataset_to_local(name: str, data: list[dict], split: str = "default") -> str:
    """Save dataset directly to local DATA_PATH.

    Args:
        name: Name of the dataset
        data: List of dictionaries containing the dataset examples
        split: Split name (e.g., 'train', 'test', 'default')

    Returns:
        str: Path to the saved parquet file
    """
    dataset_dir = os.path.join(DATA_ROOT_DIR, name)
    os.makedirs(dataset_dir, exist_ok=True)

    # Convert to DataFrame and save
    data_df = pd.DataFrame(data)
    dataset_path = os.path.join(dataset_dir, f"{split}.parquet")
    data_df.to_parquet(dataset_path)

    print(
        f"Saved dataset '{name}' split '{split}' with {len(data)} examples at {dataset_path}. Make sure to set the environment variable {TASKSET_PATH_ENV_VAR} to {DATA_ROOT_DIR}/{name}."
    )

    return dataset_path


def prepare_frozenlake_data(train_size=10000, test_size=100, map_max_size=6):
    """
    Prepare and save FrozenLake datasets for training and testing.

    Args:
        train_size (int): Number of training examples to generate
        test_size (int): Number of test examples to generate
        map_max_size (int): Exclusive upper bound on the randomly sampled map size

    Returns:
        tuple: (train_data, test_data) - Lists of data dictionaries
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Generate random parameters for train and test sets
    train_seeds = np.random.randint(0, 100000, size=train_size)
    test_seeds = np.random.randint(0, 100000, size=test_size)
    train_sizes = np.random.randint(2, map_max_size, size=train_size)
    test_sizes = np.random.randint(2, map_max_size, size=test_size)
    train_ps = np.random.uniform(0.6, 0.85, size=train_size)
    test_ps = np.random.uniform(0.6, 0.85, size=test_size)

    def frozenlake_process_fn(seed, size, p, idx):
        """Process function to create FrozenLake task instances."""
        return {"seed": seed, "size": size, "p": p, "index": idx, "uid": f"{seed}_{size}_{p}"}

    # Create train and test data
    train_data = [
        frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
        for idx, seed in enumerate(train_seeds)
    ]
    test_data = [
        frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
        for idx, seed in enumerate(test_seeds)
    ]

    # Save datasets directly to local DATA_PATH
    save_dataset_to_local("frozenlake", train_data, "train")
    save_dataset_to_local("frozenlake", test_data, "test")

    return train_data, test_data


if __name__ == "__main__":
    train_data, test_data = prepare_frozenlake_data()
    print(f"Train dataset: {len(train_data)} examples")
    print(f"Test dataset: {len(test_data)} examples")
    print("Sample train example:", train_data[0])
    print("Sample test example:", test_data[0])

tests/common/vllm_test.py

Lines changed: 62 additions & 6 deletions
@@ -234,6 +234,7 @@ def setUp(self):
        self.config.model.max_model_len = self.max_model_len
        self.config.model.max_prompt_tokens = self.max_prompt_tokens
        self.config.model.max_response_tokens = self.max_response_tokens
+        self.config.model.enable_prompt_truncation = True
        self.config.explorer.rollout_model.enable_openai_api = True
        self.config.check_and_update()

@@ -246,14 +247,21 @@ async def test_model_len(self):
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What's the weather like today?"},
        ]
+
+        # For vllm engine, max_prompt_tokens and max_response_tokens work
        response = self.model_wrapper.chat(messages)
        self.assertEqual(len(response), 1)
-        self.assertEqual(len(response[0].tokens), self.max_model_len)
+        self.assertEqual(len(response[0].tokens), self.config.model.max_model_len)
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
-        self.assertEqual(len(exps[0].tokens), self.max_model_len)
+        # check prompt length, response length, max_model_len
+        self.assertEqual(exps[0].prompt_length, self.config.model.max_prompt_tokens)
+        self.assertEqual(
+            len(exps[0].tokens) - exps[0].prompt_length, self.config.model.max_response_tokens
+        )
+        self.assertLessEqual(len(response[0].tokens), self.config.model.max_model_len)

-        # max_prompt_tokens and max_response_tokens do not work with openai api
+        # For openai api, max_prompt_tokens and max_response_tokens do not work
        openai_client = self.model_wrapper.get_openai_client()
        model_id = openai_client.models.list().data[0].id
        with self.assertRaises(BadRequestError):

@@ -267,9 +275,57 @@ async def test_model_len(self):
        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), 1)
        # only generate max_response_tokens tokens
-        self.assertEqual(
-            len(exps[0].tokens),
-            response.usage.prompt_tokens + self.config.model.max_response_tokens,
+        self.assertLessEqual(
+            len(exps[0].tokens) - response.usage.prompt_tokens,
+            self.config.model.max_response_tokens,
+        )
+
+
+class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = 20
+        self.config.model.max_prompt_tokens = 1
+        self.config.model.max_response_tokens = None
+        self.config.model.enable_prompt_truncation = False
+        self.config.explorer.rollout_model.enable_openai_api = True
+        self.config.check_and_update()
+
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_model_len(self):
+        await self.model_wrapper.prepare()
+        messages = [
+            {"role": "user", "content": "How are you?"},
+        ]
+
+        # For vllm engine, max_prompt_tokens and max_response_tokens work
+        response = self.model_wrapper.chat(messages)
+        self.assertEqual(len(response), 1)
+        self.assertLessEqual(
+            len(response[0].tokens) - response[0].prompt_length,
+            self.config.model.max_response_tokens,
+        )
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertLessEqual(
+            len(exps[0].tokens) - exps[0].prompt_length,
+            self.config.model.max_response_tokens,
+        )
+
+        # For openai api
+        openai_client = self.model_wrapper.get_openai_client()
+        model_id = openai_client.models.list().data[0].id
+        response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1)
+        self.assertEqual(len(response.choices), 1)
+        exps = self.model_wrapper.extract_experience_from_history()
+        self.assertEqual(len(exps), 1)
+        self.assertLessEqual(
+            len(exps[0].tokens) - response.usage.prompt_tokens,
+            self.config.model.max_response_tokens,
        )