diff --git a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md index e0d2024f4f..19e683dcee 100644 --- a/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source/tutorial/example_reasoning_basic.md @@ -29,13 +29,14 @@ Download the GSM8K dataset to the local directory `$DATASET_PATH/gsm8k`: ```bash # Using Modelscope -modelscope download --dataset modelscope/gsm8k --local_dir $DATASET_PATH/gsm8k +modelscope download --dataset AI-ModelScope/gsm8k --local_dir $DATASET_PATH/gsm8k # Using Huggingface huggingface-cli download openai/gsm8k --repo-type dataset --local-dir $DATASET_PATH/gsm8k ``` More details on dataset downloading are referred to [ModelScope](https://modelscope.cn/docs/datasets/download) or [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space). +The dataset downloaded from ModelScope may lack the `dtype` field, which can cause an error when loading the dataset. To solve this issue, delete the `dataset_infos.json` file and rerun the experiment. ## Step 2: Set up Configuration and Run Experiment diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index f25738f546..e9cf4fcded 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -163,6 +163,7 @@ model: max_prompt_tokens: 4096 max_response_tokens: 16384 min_response_tokens: 1 + enable_prompt_truncation: true ``` - `model_path`: Path to the model being trained. @@ -173,6 +174,7 @@ model: - `max_response_tokens`: Maximum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`. - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`. +- `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated, and the prompt length plus the response length may exceed `max_model_len`. ```{tip} If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the value of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API.
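To make the `enable_prompt_truncation` behavior concrete, the sketch below reproduces the truncation rule at the tokenizer level. It calls a Hugging Face tokenizer directly; the variable names mirror the config fields above and the model path is only an example — this is an illustration of the behavior, not Trinity-RFT's internal implementation.

```python
# Illustration of the enable_prompt_truncation rule, assuming a Hugging Face tokenizer.
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2.5-1.5B-Instruct"  # example model path
max_prompt_tokens = 4096
enable_prompt_truncation = True

tokenizer = AutoTokenizer.from_pretrained(model_path)
long_prompt = "..."  # a prompt that may exceed max_prompt_tokens

token_ids = tokenizer(
    long_prompt,
    truncation=enable_prompt_truncation,
    max_length=max_prompt_tokens if enable_prompt_truncation else None,
)["input_ids"]

# With truncation enabled, len(token_ids) <= max_prompt_tokens.
# With it disabled, the full prompt is kept and prompt + response may exceed max_model_len.
```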
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_reasoning_basic.md b/docs/sphinx_doc/source_zh/tutorial/example_reasoning_basic.md index 17129a6555..1a80338b78 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_reasoning_basic.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_reasoning_basic.md @@ -30,13 +30,14 @@ huggingface-cli download Qwen/Qwen2.5-1.5B-Instruct --local-dir $MODEL_PATH/Qwen ```bash # 使用 Modelscope -modelscope download --dataset modelscope/gsm8k --local_dir $DATASET_PATH/gsm8k +modelscope download --dataset AI-ModelScope/gsm8k --local_dir $DATASET_PATH/gsm8k # 使用 Huggingface huggingface-cli download openai/gsm8k --repo-type dataset --local-dir $DATASET_PATH/gsm8k ``` 更多关于数据集下载的细节请参考 [ModelScope](https://modelscope.cn/docs/datasets/download) 或 [Huggingface](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli#download-a-dataset-or-a-space)。 +从 ModelScope 下载的数据集可能缺少 `dtype` 字段,导致加载数据集时出错。要解决这个问题,请删除 `dataset_infos.json` 文件并重新运行实验。 ## 第 2 步:配置实验并运行 diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index c6dc344007..7227423385 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -163,6 +163,7 @@ model: max_prompt_tokens: 4096 max_response_tokens: 16384 min_response_tokens: 1 + enable_prompt_truncation: true ``` - `model_path`: 被训练模型的路径。 @@ -173,6 +174,7 @@ model: - `max_prompt_tokens`: 输入 prompt 中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 +- `enable_prompt_truncation`: 是否截断 prompt。默认为 `true`。若设置为 `true`,则 prompt 将被截断为 `max_prompt_tokens` 个 token;若设置为 `false`,则 prompt 不会被截断,存在 prompt 和 response 长度之和超过 `max_model_len` 的风险。 ```{tip} 如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略,在没有独立指定 `max_tokens` 时,每次 API 调用将生成最多 `max_model_len - prompt_length` 个 token,因此在使用时请确保 prompt 长度小于 `max_model_len`。 diff --git a/examples/grpo_frozen_lake/README.md b/examples/grpo_frozen_lake/README.md new file mode 100644 index 0000000000..ec79861848 --- /dev/null +++ b/examples/grpo_frozen_lake/README.md @@ -0,0 +1,43 @@ +# Frozen Lake + +This example shows the usage of GRPO on the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task. Note that this task is only tested with Qwen2.5 Instruct models. + + +## Data and Environment Preparation + +After setting up the basic environment following the [installation guidance](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html), you need to install the additional dependencies by running the following command: + +```bash +pip install gymnasium[toy_text] +``` + +Then, we prepare the dataset by running the following command: + +```bash +cd examples/grpo_frozen_lake +python get_frozen_lake_data.py +``` + +This command will save the dataset to the local directory `/path/to/frozenlake`, and print the path of the dataset. Afterwards, make sure to set the environment variable `TRINITY_TASKSET_PATH` to the path of the dataset. +```bash +export TRINITY_TASKSET_PATH=/path/to/frozenlake +``` + + +## Workflow Configuration and Training + +We use a concatenated multi-turn workflow `FrozenLakeWorkflow` to solve the Frozen Lake task. 
For each rollout, the multi-turn interaction between the agent and the environment is stored in a single `Experience` object. +The specific configuration is located in [`frozen_lake.yaml`](frozen_lake.yaml). + +To run this example, you can use the following command: + +```bash +trinity run --config examples/grpo_frozen_lake/frozen_lake.yaml +``` + +## Results +We show results with a Qwen2.5-3B-Instruct model below. The figures show that both the reward and the test score increase over training steps. + +![reward](frozen_lake_reward.png) + +![test_score](frozen_lake_test_score.png) diff --git a/examples/grpo_frozen_lake/frozen_lake.yaml b/examples/grpo_frozen_lake/frozen_lake.yaml new file mode 100644 index 0000000000..648bab71bc --- /dev/null +++ b/examples/grpo_frozen_lake/frozen_lake.yaml @@ -0,0 +1,85 @@ +project: "FrozenLake" +name: "trinity-frozen-lake" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + optimizer: + lr: 1e-6 + policy_loss_fn_args: + loss_agg_mode: "seq-mean-token-sum" + clip_range_low: 0.2 + clip_range_high: 0.28 + kl_loss_fn_args: + kl_coef: 0.0 +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} + enable_prompt_truncation: false + max_response_tokens: 10240 + max_model_len: 14436 + temperature: 0.7 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 64 + explorer_input: + taskset: + name: frozenlake + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} + split: train + workflow_args: + env_max_steps: 8 + agent_max_steps: 10 + is_slippery: false + eval_tasksets: + - name: frozenlake + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} + split: test + workflow_args: + env_max_steps: 8 + agent_max_steps: 10 + is_slippery: false + rollout_args: + n: 4 + top_p: 0.8 + top_k: 20 + default_workflow_type: 'frozen_lake_workflow' +explorer: + eval_on_startup: true + eval_interval: 10 + runner_per_model: 8 + rollout_model: + engine_num: 6 + tensor_parallel_size: 1 + enable_chunked_prefill: true + enforce_eager: false + dtype: bfloat16 + seed: 42 + gpu_memory_utilization: 0.85 +trainer: + trainer_type: 'verl' + save_interval: 1000 + use_dynamic_bsz: true + max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 1 + trainer_config: + actor_rollout_ref: + hybrid_engine: true + model: + use_remove_padding: true + enable_gradient_checkpointing: true + actor: + fsdp_config: + param_offload: true + optimizer_offload: true + ref: + fsdp_config: + param_offload: true +synchronizer: + sync_method: nccl + sync_interval: 1 + sync_timeout: 1200 diff --git a/examples/grpo_frozen_lake/frozen_lake_reward.png b/examples/grpo_frozen_lake/frozen_lake_reward.png new file mode 100644 index 0000000000..4927a2702c Binary files /dev/null and b/examples/grpo_frozen_lake/frozen_lake_reward.png differ diff --git a/examples/grpo_frozen_lake/frozen_lake_test_score.png b/examples/grpo_frozen_lake/frozen_lake_test_score.png new file mode 100644 index 0000000000..a7dc5cd555 Binary files /dev/null and b/examples/grpo_frozen_lake/frozen_lake_test_score.png differ diff --git a/examples/grpo_frozen_lake/get_frozen_lake_data.py b/examples/grpo_frozen_lake/get_frozen_lake_data.py new file mode 100644 index 0000000000..17b4aae87b --- /dev/null +++ b/examples/grpo_frozen_lake/get_frozen_lake_data.py @@ -0,0 +1,92 @@ +""" +Modified from https://github.com/rllm-org/rllm/blob/main/examples/frozenlake/prepare_frozenlake_data.py +"""
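# Note: this script writes <data_root>/frozenlake/train.parquet and
# <data_root>/frozenlake/test.parquet, where each row stores the map-generation
# parameters of one task: {"seed": int, "size": int, "p": float, "index": int, "uid": str}.
# Point TRINITY_TASKSET_PATH at the printed .../frozenlake directory before launching training.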
+import os + +import numpy as np +import pandas as pd + +from trinity.common.constants import TASKSET_PATH_ENV_VAR + +path_from_env = os.environ.get(TASKSET_PATH_ENV_VAR) +if path_from_env is not None: + DATA_ROOT_DIR = os.path.dirname(path_from_env) +else: + DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), "data") + + +def save_dataset_to_local(name: str, data: list[dict], split: str = "default") -> str: + """Save dataset directly to local DATA_PATH. + + Args: + name: Name of the dataset + data: List of dictionaries containing the dataset examples + split: Split name (e.g., 'train', 'test', 'default') + + Returns: + str: Path to the saved parquet file + """ + dataset_dir = os.path.join(DATA_ROOT_DIR, name) + os.makedirs(dataset_dir, exist_ok=True) + + # Convert to DataFrame and save + data_df = pd.DataFrame(data) + dataset_path = os.path.join(dataset_dir, f"{split}.parquet") + data_df.to_parquet(dataset_path) + + print( + f"Saved dataset '{name}' split '{split}' with {len(data)} examples at {dataset_path}. Make sure to set the environment variable {TASKSET_PATH_ENV_VAR} to {DATA_ROOT_DIR}/{name}." + ) + + return dataset_path + + +def prepare_frozenlake_data(train_size=10000, test_size=100, map_max_size=6): + """ + Prepare and save FrozenLake datasets for training and testing. + + Args: + train_size (int): Number of training examples to generate + test_size (int): Number of test examples to generate + + Returns: + tuple: (train_data, test_data) - Lists of data dictionaries + """ + # Set random seed for reproducibility + np.random.seed(42) + + # Generate random parameters for train and test sets + train_seeds = np.random.randint(0, 100000, size=train_size) + test_seeds = np.random.randint(0, 100000, size=test_size) + train_sizes = np.random.randint(2, map_max_size, size=train_size) + test_sizes = np.random.randint(2, map_max_size, size=test_size) + train_ps = np.random.uniform(0.6, 0.85, size=train_size) + test_ps = np.random.uniform(0.6, 0.85, size=test_size) + + def frozenlake_process_fn(seed, size, p, idx): + """Process function to create FrozenLake task instances.""" + return {"seed": seed, "size": size, "p": p, "index": idx, "uid": f"{seed}_{size}_{p}"} + + # Create train and test data + train_data = [ + frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx) + for idx, seed in enumerate(train_seeds) + ] + test_data = [ + frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx) + for idx, seed in enumerate(test_seeds) + ] + + # Save datasets directly to local DATA_PATH + save_dataset_to_local("frozenlake", train_data, "train") + save_dataset_to_local("frozenlake", test_data, "test") + + return train_data, test_data + + +if __name__ == "__main__": + train_data, test_data = prepare_frozenlake_data() + print(f"Train dataset: {len(train_data)} examples") + print(f"Test dataset: {len(test_data)} examples") + print("Sample train example:", train_data[0]) + print("Sample test example:", test_data[0]) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index b578cf190d..2f281f4b32 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -234,6 +234,7 @@ def setUp(self): self.config.model.max_model_len = self.max_model_len self.config.model.max_prompt_tokens = self.max_prompt_tokens self.config.model.max_response_tokens = self.max_response_tokens + self.config.model.enable_prompt_truncation = True self.config.explorer.rollout_model.enable_openai_api = True self.config.check_and_update() @@ -246,14 +247,21 @@ async def test_model_len(self): 
{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What's the weather like today?"}, ] + + # For vllm engine, max_prompt_tokens and max_response_tokens work response = self.model_wrapper.chat(messages) self.assertEqual(len(response), 1) - self.assertEqual(len(response[0].tokens), self.max_model_len) + self.assertEqual(len(response[0].tokens), self.config.model.max_model_len) exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 1) - self.assertEqual(len(exps[0].tokens), self.max_model_len) + # check prompt length, response length, max_model_len + self.assertEqual(exps[0].prompt_length, self.config.model.max_prompt_tokens) + self.assertEqual( + len(exps[0].tokens) - exps[0].prompt_length, self.config.model.max_response_tokens + ) + self.assertLessEqual(len(response[0].tokens), self.config.model.max_model_len) - # max_prompt_tokens and max_response_tokens do not work with openai api + # For openai api, max_prompt_tokens and max_response_tokens do not work openai_client = self.model_wrapper.get_openai_client() model_id = openai_client.models.list().data[0].id with self.assertRaises(BadRequestError): @@ -267,9 +275,57 @@ async def test_model_len(self): exps = self.model_wrapper.extract_experience_from_history() self.assertEqual(len(exps), 1) # only generate max_response_tokens tokens - self.assertEqual( - len(exps[0].tokens), - response.usage.prompt_tokens + self.config.model.max_response_tokens, + self.assertLessEqual( + len(exps[0].tokens) - response.usage.prompt_tokens, + self.config.model.max_response_tokens, + ) + + +class TestModelLenWithoutPromptTruncation(RayUnittestBaseAysnc): + def setUp(self): + self.config = get_template_config() + self.config.mode = "explore" + self.config.model.model_path = get_model_path() + self.config.model.max_model_len = 20 + self.config.model.max_prompt_tokens = 1 + self.config.model.max_response_tokens = None + self.config.model.enable_prompt_truncation = False + self.config.explorer.rollout_model.enable_openai_api = True + self.config.check_and_update() + + self.engines, self.auxiliary_engines = create_inference_models(self.config) + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) + + async def test_model_len(self): + await self.model_wrapper.prepare() + messages = [ + {"role": "user", "content": "How are you?"}, + ] + + # For vllm engine, max_prompt_tokens and max_response_tokens work + response = self.model_wrapper.chat(messages) + self.assertEqual(len(response), 1) + self.assertLessEqual( + len(response[0].tokens) - response[0].prompt_length, + self.config.model.max_response_tokens, + ) + exps = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(exps), 1) + self.assertLessEqual( + len(exps[0].tokens) - exps[0].prompt_length, + self.config.model.max_response_tokens, + ) + + # For openai api + openai_client = self.model_wrapper.get_openai_client() + model_id = openai_client.models.list().data[0].id + response = openai_client.chat.completions.create(model=model_id, messages=messages, n=1) + self.assertEqual(len(response.choices), 1) + exps = self.model_wrapper.extract_experience_from_history() + self.assertEqual(len(exps), 1) + self.assertLessEqual( + len(exps[0].tokens) - response.usage.prompt_tokens, + self.config.model.max_response_tokens, ) diff --git a/trinity/common/config.py b/trinity/common/config.py index bd2d6f6907..5262b8d94a 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -457,6 
+457,8 @@ class ModelConfig: max_response_tokens: Optional[int] = None # the minimum number of tokens for the response min_response_tokens: int = 1 + # whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens. + enable_prompt_truncation: bool = True # lora config lora_configs: Optional[List[LoRAConfig]] = None @@ -498,6 +500,8 @@ class InferenceModelConfig: max_response_tokens: Optional[int] = None # if not set, use `model.min_response_tokens` min_response_tokens: Optional[int] = None + # if not set, use `model.enable_prompt_truncation` + enable_prompt_truncation: Optional[bool] = None # used for testing very long response generation, do not set it unless you know what you are doing ignore_eos: bool = False @@ -1121,9 +1125,14 @@ def _check_model(self) -> None: model.critic_model_path = model.model_path # check template - if model.chat_template_path and model.custom_chat_template is None: - with open(model.chat_template_path, "r") as f: - model.custom_chat_template = f.read() + if model.chat_template_path is not None and model.custom_chat_template is None: + try: + with open(model.chat_template_path, "r") as f: + model.custom_chat_template = f.read() + except Exception as e: + raise ValueError( + f"Failed to read chat template from {model.chat_template_path}: {e}" + ) # check max_model_len, max_prompt_tokens, max_response_tokens @@ -1178,6 +1187,19 @@ def _check_model(self) -> None: model.min_response_tokens = max(model.max_response_tokens - 1, 0) # type: ignore [operator] logger.warning(f"`min_response_tokens` is set to {model.min_response_tokens}.") + if model.enable_prompt_truncation is True: + if model.max_prompt_tokens is None: + raise ValueError( + "When `model.enable_prompt_truncation` is True, `model.max_prompt_tokens` must be set properly." + ) + logger.warning( + f"`enable_prompt_truncation` is set to True; the prompt will be truncated to `max_prompt_tokens`={model.max_prompt_tokens} tokens if it is too long." + ) + else: + logger.warning( + "`enable_prompt_truncation` is set to False; please make sure the prompt is not too long and `max_model_len` is large enough, otherwise prompt length + response length may exceed `max_model_len`!" + ) + def __iter__(self): """Iterate over configs with each stage applied in order. @@ -1248,6 +1270,7 @@ def check_and_update(self) -> Config: # noqa: C901 "max_prompt_tokens", "max_response_tokens", "min_response_tokens", + "enable_prompt_truncation", ] rope_args = ["rope_scaling", "rope_theta"] model_args = rollout_args + length_args + rope_args diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index a2ecdd90a4..5ed1c1689c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -276,6 +276,9 @@ async def get_lora_request_async(self) -> Optional[LoRARequest]: else: return None + async def get_message_token_len(self, messages: List[dict]) -> int: + return await self.model.get_message_token_len.remote(messages) + def get_openai_client(self) -> openai.OpenAI: """Get the openai client. 
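The new `get_message_token_len` helper on `ModelWrapper` is what lets a multi-turn workflow keep its per-episode generation budget under control. Below is a simplified sketch of the budgeting pattern that `FrozenLakeWorkflow` uses in this PR; `model` stands for a `ModelWrapper`, the wrapper function and the numeric budget are illustrative, and error handling is omitted.

```python
# Simplified per-turn token budgeting built on ModelWrapper.get_message_token_len.
# The pattern mirrors FrozenLakeWorkflow.run_async; names and numbers are examples.
MAX_RESPONSE_TOKENS = 10240  # total generation budget for one episode

async def budgeted_turn(model, messages, init_prompt_token_len):
    # Tokens the chat history has grown by since the first turn count against the budget.
    current_len = await model.get_message_token_len(messages)
    remaining = MAX_RESPONSE_TOKENS - (current_len - init_prompt_token_len)
    if remaining <= 0:
        return None  # budget exhausted: the caller should end the episode
    responses = await model.chat_async(messages, n=1, max_tokens=remaining)
    return responses[0].response_text
```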
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 518233a57f..aefafb315b 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -65,11 +65,15 @@ def __init__( temperature=config.temperature, max_tokens=config.max_response_tokens, min_tokens=config.min_response_tokens, - truncate_prompt_tokens=config.max_prompt_tokens, + truncate_prompt_tokens=( + config.max_prompt_tokens if config.enable_prompt_truncation else None + ), skip_special_tokens=True, include_stop_str_in_output=False, output_kind=RequestOutputKind.FINAL_ONLY, logprobs=config.logprobs, + top_p=config.top_p, + top_k=config.top_k, ignore_eos=config.ignore_eos, ) self.enable_thinking = config.enable_thinking @@ -188,7 +192,10 @@ async def generate( if self.tokenizer is None: await self._initialize_tokenizer() token_ids = self.tokenizer( # type: ignore - prompt, truncation=True, max_length=self.config.max_prompt_tokens, return_tensors="pt" + prompt, + truncation=self.config.enable_prompt_truncation, + max_length=self.config.max_prompt_tokens, + return_tensors="pt", )["input_ids"][0].tolist() output = await self._generate_internal( prompt={"prompt_token_ids": token_ids}, lora_request=lora_request, **kwargs @@ -387,6 +394,19 @@ async def convert_messages_to_experience( chat_template=self.chat_template, enable_thinking=self.enable_thinking, ) # (seq_length, ), (seq_length, ) + + # Truncate tokens if they exceed the length limit + assert token_ids is not None + is_truncated = False # TODO: add to experience itself + if self.config.max_model_len is not None and self.config.max_model_len > 0: + if len(token_ids) > self.config.max_model_len - 1: + is_truncated = True + self.logger.warning( + f"Warning: {len(token_ids) = } exceeds the length limit {self.config.max_model_len-1 = }" + ) + token_ids = token_ids[: self.config.max_model_len - 1] + action_mask = action_mask[: self.config.max_model_len - 1] + temperature = temperature if temperature is not None else self.config.temperature logprobs = await self.logprobs( token_ids=token_ids.tolist(), temperature=temperature @@ -397,6 +417,7 @@ async def convert_messages_to_experience( prompt_length=prompt_length, action_mask=action_mask[prompt_length:], # Exclude the prompt tokens messages=messages, + info={"is_truncated": is_truncated}, ) async def shutdown(self): @@ -549,6 +570,23 @@ def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest: lora_request.lora_path = lora_path return lora_request + async def get_message_token_len(self, messages) -> int: + if self.tokenizer is None: + await self._initialize_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) + prompt_token = self.tokenizer( # type: ignore + prompt, truncation=False, return_tensors="pt" + )["input_ids"][0].tolist() + return len(prompt_token) + async def sleep(self, level: int = 1) -> None: await self.async_llm.sleep(level=level) diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 1277f041f9..847a3d8707 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -29,6 +29,7 @@ RAFTReflectAlfworldWorkflow, ) from trinity.common.workflows.envs.email_searcher.workflow import EmailSearchWorkflow +from 
trinity.common.workflows.envs.frozen_lake.workflow import FrozenLakeWorkflow from trinity.common.workflows.envs.sciworld.sciworld_workflow import SciWorldWorkflow from trinity.common.workflows.envs.webshop.webshop_workflow import WebShopWorkflow from trinity.common.workflows.eval_workflow import ( @@ -94,4 +95,5 @@ "SimpleMMWorkflow", "RubricJudgeWorkflow", "AgentScopeWorkflowAdapter", + "FrozenLakeWorkflow", ] diff --git a/trinity/common/workflows/envs/frozen_lake/utils.py b/trinity/common/workflows/envs/frozen_lake/utils.py new file mode 100644 index 0000000000..b5d0f319fc --- /dev/null +++ b/trinity/common/workflows/envs/frozen_lake/utils.py @@ -0,0 +1,162 @@ +""" +Utils for the FrozenLake environment. +Modified from https://github.com/rllm-org/rllm/blob/main/rllm/environments/frozenlake/frozenlake.py +""" + +from typing import Optional, Tuple + +import numpy as np + +# Map gym state in integer +MAP_LOOKUP = { + b"P": 0, + b"F": 1, + b"H": 2, + b"G": 3, +} + +# Define rules to transform to rendered text observation of the environment +GRID_LOOKUP = { + 0: " P \t", # player + 1: " _ \t", # frozen + 2: " O \t", # hole + 3: " G \t", # goal + 4: " X \t", # player fall into hole + 5: " √ \t", # player on goal +} + +ACTION_LOOKUP = { + 0: "None", + 1: "Left", + 2: "Down", + 3: "Right", + 4: "Up", +} + +# Prompting format inspired by the RAGEN project: https://github.com/RAGEN-AI/RAGEN +SYSTEM_PROMPT = """You are Qwen, created by Alibaba Cloud. You are a helpful assistant. You are walking on a frozen lake. + +FrozenLake Quick Guide +Goal: Reach the goal (G). Player (P) and Goal (G) must overlap. + +Symbols: +_ Frozen | O Hole | G Goal | P Player + +Rules: +1. Avoid falling into holes (O). +2. Frozen tiles are slippery, you may move perpendicular to your intended direction. + +Valid Action (separated by | ): +Up | Down | Left | Right + +Rewards: +Fall into hole: 0 +Reach goal: +1.0 + +You will be provided the current observation, please decide on the next Action. +You should show your thought process and then input the final action in ``` ```. +You should only output the NEXT ACTION at each interation in the ``` ```. For example, if you want to move up, you should output ```Up```. +You should plan ahead and need to achieve it in minimum number of steps. +You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it. + +Please show your thinking process and put the final action in ``` ```. In every turn, the final action MUST be one of Up, Down, Left, Right. +""" + + +def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool: + """DFS to check that it's a valid path. + + Args: + board: The board representation as a list of lists. + max_size: Maximum size of the board. + max_steps: Maximum number of steps allowed. + + Returns: + True if there's a valid path from start to goal within max_steps, False otherwise. 
+ """ + frontier, discovered = [], set() + # find the start point + start_r, start_c = np.where(np.array(board) == "S") + frontier.append((start_r[0], start_c[0], 0)) # row, col steps + # dfs to check if there is a path from start to goal + while frontier: + r, c, steps = frontier.pop() + if steps > max_steps: + continue + + if (r, c) not in discovered: + discovered.add((r, c)) + directions = [(1, 0), (0, 1), (-1, 0), (0, -1)] + for x, y in directions: + r_new = r + x + c_new = c + y + if r_new < 0 or r_new >= max_size or c_new < 0 or c_new >= max_size: + continue + if board[r_new][c_new] == "G": + return True + if board[r_new][c_new] != "H": + frontier.append((r_new, c_new, steps + 1)) + return False + + +def generate_random_map( + size: int = 8, p: float = 0.8, seed: int = 0, max_steps: int = 5 +) -> Tuple[list[str], Tuple[int, int]]: + """Generates a random valid map (one that has a path from start to goal). + + Args: + size: Size of each side of the grid. + p: Probability that a tile is frozen. + seed: Seed to ensure the generation of reproducible maps. + max_steps: Maximum number of steps allowed. + + Returns: + A tuple containing a random valid map and the goal position (row, col). + """ + valid = False + board: list[list[str]] = [] # initialize to make pyright happy + + try: + from gymnasium.utils import seeding + + np_random, _ = seeding.np_random(seed) + except ImportError: + raise ImportError( + "Gymnasium is not installed. Please install gymnasium first before running the frozen_lake workflow." + ) + + # generate random start and end points + while not valid: + p = min(1, p) + board = np_random.choice(["F", "H"], (size, size), p=[p, 1 - p]).tolist() + + while True: + start_r = int(np_random.integers(0, size)) + start_c = int(np_random.integers(0, size)) + goal_r = int(np_random.integers(0, size)) + goal_c = int(np_random.integers(0, size)) + + # Ensure start and goal are different positions + if (start_r, start_c) != (goal_r, goal_c): + break + + board[start_r][start_c] = "S" + board[goal_r][goal_c] = "G" + + valid = is_valid(board, size, max_steps) + return ["".join(x) for x in board], (goal_r, goal_c) + + +def get_goal_position(random_map: np.ndarray) -> Optional[Tuple[int, int]]: + """Get the goal position from a random map. + + Args: + random_map: The map as a numpy array. + + Returns: + Tuple of (row, col) if goal found, None otherwise. + """ + positions = np.argwhere(random_map == b"G") + if positions.size == 0: + return None # G not found + return tuple(positions[0]) # returns (row, col) diff --git a/trinity/common/workflows/envs/frozen_lake/workflow.py b/trinity/common/workflows/envs/frozen_lake/workflow.py new file mode 100644 index 0000000000..35fc8bce98 --- /dev/null +++ b/trinity/common/workflows/envs/frozen_lake/workflow.py @@ -0,0 +1,395 @@ +# -*- coding: utf-8 -*- +""" +This file defines a multi-step workflow for the FrozenLake environment. 
+Modified from https://github.com/rllm-org/rllm/blob/main/rllm/environments/frozenlake/frozenlake.py +""" + +from __future__ import annotations + +import copy +import re +from dataclasses import asdict +from typing import List, Optional, Tuple + +import numpy as np + +from trinity.common.experience import Experience +from trinity.common.models.model import ModelWrapper +from trinity.common.workflows.envs.frozen_lake.utils import ( + GRID_LOOKUP, + MAP_LOOKUP, + SYSTEM_PROMPT, + generate_random_map, + get_goal_position, +) +from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task + + +@WORKFLOWS.register_module("frozen_lake_workflow") +class FrozenLakeWorkflow(MultiTurnWorkflow): + """ + FrozenLake environment for multi-step workflows. + + ## Description + The game starts with the player at random location of the frozen lake grid world with the + goal located at another random location for the 4x4 environment. + + ## Action Space + The action shape is `(1,)` in the range `{0, 3}` indicating + which direction to move the player. + NOTE the action space is different from gymnasium.envs.toy_text.frozen_lake.FrozenLakeEnv, start from 1 + use action_map to map from custom action to action defined in FrozenLakeEnv in gymnasium + - 0: Still + - 1: Left + - 2: Down + - 3: Right + - 4: Up + + ## Starting State + The episode starts with the player at random location + + ## Rewards + Reward schedule: + - Reach goal: +1 + - Reach hole: 0 + - Reach frozen: 0 + + ## Arguments + `is_slippery`: if action is left and is_slippery is True, then: + - P(move left)=1/3 + - P(move up)=1/3 + - P(move down)=1/3 + + ## Example + P _ _ _ + _ _ _ O + O _ O _ + O _ _ G + """ + + can_reset: bool = False # GymFrozenLakeEnv can only reset the player position, not the environment configuration. + is_async: bool = True + can_repeat: bool = False + + def __init__( + self, + model: ModelWrapper, + task: Task, + auxiliary_models: Optional[List] = None, + ): + """Initialize the FrozenLake workflow. + + Args: + model: The model wrapper to use for generating actions. + task: The task configuration containing workflow-specific arguments. + auxiliary_models: Optional list of auxiliary models. + """ + super().__init__( + model=model, + task=task, + auxiliary_models=auxiliary_models, + ) + + # Import gymnasium here to avoid import error if gymnasium is not installed + # and this workflow is not used + try: + import gymnasium as gym + from gymnasium.envs.toy_text.frozen_lake import ( + FrozenLakeEnv as GymFrozenLakeEnv, + ) + except ImportError as e: + error_message = ( + f"Gymnasium is not installed. Please install gymnasium first before " + f"running the frozen_lake workflow. 
Error: {str(e)}" + ) + self.logger.error(error_message) + raise ImportError(error_message) + + # Extract workflow-specific arguments + workflow_args = task.workflow_args if hasattr(task, "workflow_args") else {} + self.env_max_steps = workflow_args.get("env_max_steps", 8) + self.agent_max_steps = workflow_args.get("agent_max_steps", 10) + self.desc = workflow_args.get("desc", None) + self.is_slippery = workflow_args.get("is_slippery", False) + self.max_response_tokens = self.rollout_args.get("max_response_tokens", 10240) + + # Extract task-specific arguments + self.raw_task = task.raw_task if hasattr(task, "raw_task") else {} + self.size = self.raw_task.get("size", 1) + self.p = self.raw_task.get("p", 0.8) + self.seed = self.raw_task.get("seed", 42) + + if self.desc is None: + random_map, goal_position = generate_random_map( + size=self.size, p=self.p, seed=self.seed, max_steps=self.env_max_steps + ) + else: + random_map = np.asarray(copy.deepcopy(self.desc), dtype="c") + goal_position = get_goal_position(random_map) + + self.goal_position = goal_position + + # Create the gym environment + self.gym_env = GymFrozenLakeEnv(desc=random_map[:], is_slippery=self.is_slippery) + self.action_space = gym.spaces.Discrete(4, start=1) + + # Define action map and invalid action + self.action_map = { + 1: 0, + 2: 1, + 3: 2, + 4: 3, + } # map from custom Env action to action defined in FrozenLakeEnv in gymnasium + self.invalid_action = 0 + + # Agent-related state + self.step_count: int = 0 + self.step_rewards: List[float] = [] + self.current_observation: Optional[str] = None + self.done: bool = False + self.last_observation: Optional[str] = None + + @property + def rollout_args(self): + return asdict(self.task.rollout_args) + + def _get_player_position(self) -> Tuple[int, int]: + """Get the current player position. + + Returns: + Tuple of (row, col) representing the player position. + """ + return ( + self.gym_env.s // self.gym_env.ncol, + self.gym_env.s % self.gym_env.ncol, + ) # (row, col) + + def finished(self) -> bool: + """Check if the episode is finished. + + Returns: + True if the player is on goal (G) or hole (H), False otherwise. + """ + player_pos = self._get_player_position() + return self.gym_env.desc[player_pos] in b"GH" + + def success(self) -> bool: + """Check if the agent has reached the goal (G). + + Returns: + True if the player is on goal (G), False otherwise. + """ + player_pos = self._get_player_position() + return self.gym_env.desc[player_pos] in b"G" + + def env_step(self, action: int): + """Execute a step in the environment. + + Maps custom action to gymnasium FrozenLakeEnv action and takes the step. + Checks if the action is effective (whether player moves in the env). + + Args: + action: The action to take. + + Returns: + Tuple of (observation, reward, done, info). + """ + if self.success(): + return self.render(), 1, True, {"action_is_effective": False} + + if not action: + action = self.invalid_action + + action = int(action) + if action == self.invalid_action or action not in self.action_map: + return self.render(), 0, False, {"action_is_effective": False} + + prev_player_position = int(self.gym_env.s) + + player_pos, reward, done, _, prob = self.gym_env.step(self.action_map[action]) + + obs = self.render() + return obs, reward, done, {"action_is_effective": prev_player_position != int(player_pos)} + + def render(self, mode="tiny_rgb_array"): + """Render the environment. + + Args: + mode: Rendering mode. Options: "tiny_rgb_array", "list", "state", "rgb_array", "ansi". 
+ + Returns: + Rendered observation based on the mode. + """ + assert mode in ["tiny_rgb_array", "list", "state", "rgb_array", "ansi"] + if mode in ["rgb_array", "ansi"]: + prev_render_mode = self.gym_env.render_mode + self.gym_env.render_mode = mode + obs = self.gym_env.render() + self.gym_env.render_mode = prev_render_mode + return obs + room_state = copy.deepcopy(self.gym_env.desc) + + # replace the position of start 'S' with 'F' + position_S = np.where(room_state == b"S") + room_state[position_S] = b"F" + + # replace the position of the player with 'P' + position_P = self._get_player_position() + room_state[position_P] = b"P" + + if mode == "state": + # transform 'S', 'F', 'H', 'G' to numpy integer array + room_state = np.vectorize(lambda x: MAP_LOOKUP[x])(room_state) + # add player in hole or player on goal + if self.gym_env.desc[position_P] == b"H": + room_state[position_P] = 4 + elif self.gym_env.desc[position_P] == b"G": + room_state[position_P] = 5 + return room_state + + room_state = self.render(mode="state").tolist() + + if mode == "list": + + def lookup(cell): + return GRID_LOOKUP.get(cell, "?").strip("\t").strip() + + return [" ".join(lookup(cell) for cell in row) for row in room_state] + + if mode == "tiny_rgb_array": + + def lookup(cell): + return GRID_LOOKUP.get(cell, "?") + + result = "\n".join("".join(lookup(cell) for cell in row) for row in room_state) + return result + + async def run_async(self) -> List[Experience]: + """Run the workflow and return a list of experiences. + + Returns: + List of Experience objects, one for each rollout. + """ + # Reset environment and state for a new episode + # But this only resets the player position, not the environment configuration. + self.gym_env.reset(seed=self.seed) + observation = self.render() + self.current_observation = str(observation) + self.done = False + self.step_rewards = [] + self.step_count = 0 + self.action = None + terminate_reason = None + + # Initialize messages + messages = [] + system_prompt = SYSTEM_PROMPT + messages.append({"role": "system", "content": system_prompt}) + + # Run episode until done or max_steps reached + for step in range(self.agent_max_steps): + # Format observation for the model + user_prompt_content = ( + f"Current Observation ({self.step_count}): \n" + + self.current_observation + + "\n" + + "You have not achieved the goal, P has not reached G yet. Please give the next action." + ) + + if self.step_count > 0 and self.action is not None: + if self.last_observation == self.current_observation: + user_prompt_content += "\nYour last response is invalid. Your position didn't change at all. You may need to recheck your thinking process, action outputted, and the format of response. Remember, you should only output the NEXT ACTION at each interation in the ``` ```. For example, if you want to move up, you should output ```Up```." + + if self.agent_max_steps is not None and self.agent_max_steps - self.step_count > 0: + user_prompt_content += f"\nThe maximum number of steps remaining is {self.agent_max_steps - self.step_count}." 
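            # Token-budget note: the lines below append the user message and then cap the
            # total growth of the chat history (model responses plus per-turn user prompts)
            # at max_response_tokens. init_prompt_token_len is measured on the first turn;
            # once the remaining budget is exhausted, the last user message is dropped and
            # the episode ends with terminate_reason = "max_tokens_reached".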
+ + messages.append({"role": "user", "content": user_prompt_content}) + + messages_token_len = await self.model.get_message_token_len(messages) + if step == 0: + max_tokens = self.max_response_tokens + init_prompt_token_len = messages_token_len + else: + response_token_len = messages_token_len - init_prompt_token_len + max_tokens = self.max_response_tokens - response_token_len + + if max_tokens <= 0: + messages = messages[:-1] # Remove the last user message + self.done = False + self.step_rewards.append(0) + terminate_reason = "max_tokens_reached" + break + + # Get action from the model + rollout_args = self.rollout_args.copy() + rollout_args["n"] = 1 + rollout_args["max_tokens"] = max_tokens + responses = await self.model.chat_async(messages, **rollout_args) + response_text = responses[0].response_text + messages.append({"role": "assistant", "content": response_text}) + + # Parse action from response + _, action_str = self._parse_model_response(response_text) + action = int(action_str) if action_str.isdigit() else self.invalid_action + self.action = action + + # Execute action in the environment + observation, reward, done, info = self.env_step(action) + + # Update internal state + self.last_observation = self.current_observation + self.current_observation = str(observation) + self.done = done + self.step_rewards.append(reward) + self.step_count += 1 + + if self.done: + terminate_reason = "success" + break + if terminate_reason is None: + terminate_reason = "max_steps_reached" + + # Create experience from messages + final_reward = sum(self.step_rewards) + # print(f"final_reward: {final_reward}, terminate_reason: {terminate_reason}") + experience = self.process_messages_to_experience( + messages=messages, + reward=final_reward, + info={ + "env_steps": self.step_count, + "env_done": 1 if self.done else 0, + "test_score": final_reward, + }, + ) + return [experience] + + def _parse_model_response(self, response: str) -> tuple[str, str]: + """Parse the model response to extract thought and action. + + Args: + response: The model's response text. + + Returns: + Tuple of (thought, action_str). 
+ """ + DIRECTION_MAP = {"left": 1, "down": 2, "right": 3, "up": 4} + + thought = response + action_str = str(self.invalid_action) + + matches = re.findall(r"```(.*?)```", response, re.DOTALL) + + if matches: + last_match_content = matches[-1].strip() + last_match_index = response.rfind(f"```{last_match_content}```") + if last_match_index != -1: + thought = response[:last_match_index].strip() + + extracted_text = last_match_content.lower() + + if extracted_text in DIRECTION_MAP: + action_str = str(DIRECTION_MAP[extracted_text]) + elif extracted_text.isdigit() and int(extracted_text) in DIRECTION_MAP.values(): + action_str = str(int(extracted_text)) + + return thought, action_str diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 91716e1688..0798c3e65a 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -168,6 +168,9 @@ def set_repeat_times(self, repeat_times, run_id_base): def process_messages_to_experience(self, messages, reward, info={}) -> Experience: converted_experience = self.model.convert_messages_to_experience(messages) + if converted_experience.info.get("is_truncated", False): + reward = 0.0 + tokens = converted_experience.tokens log_probs = converted_experience.logprobs assert converted_experience.action_mask is not None @@ -182,6 +185,9 @@ def process_messages_to_experience(self, messages, reward, info={}) -> Experienc experience = Experience( tokens=tokens, action_mask=generation_mask, + prompt_length=converted_experience.prompt_length, + prompt_text=converted_experience.prompt_text, + response_text=converted_experience.response_text, reward=reward, logprobs=log_probs, info=info, diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index a2c11ff49f..73cbeaa96e 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -451,8 +451,12 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 metrics.update(prefix_metrics(kl_metrics, prefix="critic")) # compute advantages, executed on the driver process batch, _ = self.advantage_fn(batch) + else: + # skip token_level_scores for sft/dpo + if "token_level_scores" in batch.batch.keys(): + batch.batch["token_level_scores"] = batch.batch["token_level_scores"] - # update critic + # update critic if self.algorithm.use_critic: with marked_timer("update_critic", timing_raw): critic_output = self.critic_wg.update_critic(batch)
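For reference, here is a small, self-contained illustration of the action-parsing convention that `_parse_model_response` implements: the last triple-backtick block in the model's reply is extracted and mapped to an action id, and anything unparsable falls back to the invalid action. The regex and direction map mirror the code above; the standalone function and the example strings are only for demonstration.

````python
# Standalone illustration of the action-parsing rule used by _parse_model_response:
# take the last ``` ... ``` block, map a direction word (or digit) to the action id,
# and treat everything else as the invalid action 0.
import re

DIRECTION_MAP = {"left": 1, "down": 2, "right": 3, "up": 4}
INVALID_ACTION = 0

def parse_action(response: str) -> int:
    matches = re.findall(r"```(.*?)```", response, re.DOTALL)
    if not matches:
        return INVALID_ACTION
    text = matches[-1].strip().lower()
    if text in DIRECTION_MAP:
        return DIRECTION_MAP[text]
    if text.isdigit() and int(text) in DIRECTION_MAP.values():
        return int(text)
    return INVALID_ACTION

print(parse_action("The goal is below me, so I move down. ```Down```"))  # -> 2
print(parse_action("I am not sure what to do."))                         # -> 0
````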