diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index f60d39ccf9..75d353fd4a 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 
 import ray
@@ -935,3 +936,59 @@ async def test_api_tool_calls(self):
         print_debug(
             "\n" + "=" * 28 + f" test_api_tool_calls PASSED in {total_time:.2f}s " + "=" * 28 + "\n"
         )
+
+
+class TestSuperLongGeneration(RayUnittestBaseAysnc):
+    def setUp(self):
+        self.config = get_template_config()
+        self.config.mode = "explore"
+        self.config.model.model_path = get_model_path()
+        self.config.model.max_model_len = 81920
+        self.config.model.max_prompt_tokens = 61440
+        self.config.model.max_response_tokens = 20480
+        self.config.model.rope_scaling = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": 40960,
+        }
+        self.config.explorer.rollout_model.engine_type = "vllm"
+        self.config.explorer.rollout_model.engine_num = 1
+        self.config.explorer.rollout_model.tensor_parallel_size = 1
+        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
+
+        self.config.check_and_update()
+        self.engines, self.auxiliary_engines = create_inference_models(self.config)
+        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)
+
+    async def test_generate(self):
+        base_dir = os.path.dirname(__file__)
+        target_dir = os.path.join(base_dir, "..", "..", "trinity", "trainer", "verl")
+        with open(os.path.join(target_dir, "fsdp_workers.py")) as f:
+            fsdp_code = f.read()
+        with open(os.path.join(target_dir, "megatron_workers.py")) as f:
+            megatron_code = f.read()
+        target_dir = os.path.join(base_dir, "..", "..", "trinity", "common")
+        with open(os.path.join(target_dir, "config.py")) as f:
+            config_code = f.read()
+        target_dir = os.path.join(base_dir, "..", "..", "trinity", "manager")
+        with open(os.path.join(target_dir, "config_manager.py")) as f:
+            config_manager_code = f.read()
+
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": """# Please add comments and documentation for these following code, """
+                """make sure the code is well-structured and easy to read, """
+                """and the complete code must be shown, do not omit any parts.\n"""
+                f"""## fsdp_workers.py\n{fsdp_code}\n"""
+                f"""## megatron_workers.py\n{megatron_code}\n"""
+                f"""## config.py\n{config_code}\n"""
+                f"""## config_manager.py\n{config_manager_code}\n""",
+            },
+        ]
+        response = self.model_wrapper.chat(messages, n=1, temperature=0.7, logprobs=True)[0]
+        self.assertGreater(
+            response.prompt_length, 40960
+        )  # If not long enough, please add more files to prompt
+        self.assertGreater(response.logprobs.shape[0], 1000)
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index c99b7d1dda..a74dfeb7dc 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -73,6 +73,12 @@ class TestTrainerCountdown(BaseTrainerCase):
     def test_trainer(self):
         """Test the both and bench mode."""
         # test both mode
+        self.config.model.rope_scaling = {
+            "rope_type": "yarn",
+            "factor": 2.0,
+            "original_max_position_embeddings": 16384,
+        }
+        self.config.model.rope_theta = 10000
         self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
         self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
             selector_type="shuffle", seed=42
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c722959b96..6ac31d45d9 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -437,6 +437,10 @@ class ModelConfig:
     fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None
 
+    # rope config
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+
 
 @dataclass
 class InferenceModelConfig:
@@ -498,6 +502,10 @@ class InferenceModelConfig:
     lora_modules: Optional[List[Dict]] = None
     lora_kwargs: Optional[dict] = field(default_factory=dict)
 
+    # ! DO NOT SET, rope config
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+
 
 @dataclass
 class AlgorithmConfig:
@@ -1190,12 +1198,14 @@ def check_and_update(self) -> Config:  # noqa: C901
             "max_response_tokens",
             "min_response_tokens",
         ]
-        for args in ["model_path"] + rollout_args + length_args:
+        rope_args = ["rope_scaling", "rope_theta"]
+        model_args = rollout_args + length_args + rope_args
+        for args in ["model_path"] + model_args:
             setattr(self.explorer.rollout_model, args, getattr(self.model, args))
         for aux_model in self.explorer.auxiliary_models:
             if not aux_model.model_path:
                 raise ValueError("auxiliary model's model_path is required.")
-            for args in rollout_args + length_args:
+            for args in model_args:
                 set_if_none(aux_model, args, getattr(self.model, args))
 
         # for lora configs
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index d2d3b25c68..5d1b32716a 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -77,6 +77,11 @@ def __init__(
         max_model_len = config.max_model_len
         self.enable_lora = config.enable_lora
         self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None)
+        rope_kwargs = {
+            key: getattr(config, key)
+            for key in ["rope_scaling", "rope_theta"]
+            if getattr(config, key) is not None
+        }
         engine_args = vllm.AsyncEngineArgs(
             model=config.model_path,
             enforce_eager=config.enforce_eager,
@@ -101,6 +106,7 @@ def __init__(
             disable_log_stats=True,
             enable_lora=config.enable_lora,
             logprobs_mode="processed_logprobs",
+            **rope_kwargs,
             **config.lora_kwargs,
         )
         if get_vllm_version() > parse_version("0.10.0"):
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index 49a241393a..e8ca94484a 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -40,6 +40,10 @@ class ActorModel:
     lora_alpha: int = 32
     target_modules: Optional[str] = "all-linear"
 
+    # rope configs
+    rope_scaling: Optional[dict] = None
+    rope_theta: Optional[float] = None
+
 
 @dataclass
 class Optim:
@@ -412,6 +416,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         # Actor / Rollout Config
         self.actor_rollout_ref.model.path = config.model.model_path
         self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
+        self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
+        self.actor_rollout_ref.model.rope_theta = config.model.rope_theta
         self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps
         self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size
         self.actor_rollout_ref.rollout.temperature = (
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 899e991432..fedc54bb55 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -257,6 +257,12 @@ def _build_model_optimizer(  # noqa: C901
             local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
         )
 
+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            actor_model_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            actor_model_config.rope_theta = self.config.model.rope_theta
+
         # patch for kimi-vl
         if getattr(actor_model_config, "model_type", None) == "kimi_vl":
             actor_model_config.text_config.topk_method = "greedy"
diff --git a/trinity/trainer/verl/megatron_workers.py b/trinity/trainer/verl/megatron_workers.py
index a186d799ea..0735a477e5 100644
--- a/trinity/trainer/verl/megatron_workers.py
+++ b/trinity/trainer/verl/megatron_workers.py
@@ -151,6 +151,82 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
         )
         self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
 
+    def _init_hf_config_and_tf_config(
+        self,
+        model_path,
+        tokenizer_or_path,
+        dtype,
+        override_model_config,
+        override_transformer_config,
+        trust_remote_code=False,
+        use_mbridge=False,
+    ):
+        from transformers import AutoConfig
+        from verl.models.mcore import hf_to_mcore_config
+        from verl.utils import hf_processor, hf_tokenizer
+        from verl.utils.fs import copy_to_local
+        from verl.utils.model import update_model_config
+
+        # Step 1: initialize the tokenizer
+        self.local_path = copy_to_local(model_path)
+        if tokenizer_or_path is None:
+            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
+            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
+        elif isinstance(tokenizer_or_path, str):
+            self.tokenizer = hf_tokenizer(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+            self.processor = hf_processor(
+                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
+            )
+        else:
+            self.tokenizer = tokenizer_or_path
+            self.processor = tokenizer_or_path
+
+        if self.config.model.get("custom_chat_template", None) is not None:
+            if self.processor is not None:
+                self.processor.chat_template = self.config.model.custom_chat_template
+            else:
+                self.tokenizer.chat_template = self.config.model.custom_chat_template
+
+        # Step 2: get the hf
+        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)
+
+        # Step 3: override the hf config
+        override_config_kwargs = {
+            "bos_token_id": self.tokenizer.bos_token_id,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.pad_token_id,
+        }
+        override_config_kwargs.update(override_model_config.get("model_config", {}))
+
+        # patch for rope
+        if self.config.model.rope_scaling is not None:
+            hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
+        if self.config.model.rope_theta is not None:
+            hf_config.rope_theta = self.config.model.rope_theta
+
+        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
+        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
+        self.architectures = getattr(hf_config, "architectures", None)
+        if self.rank == 0:
+            print(f"Model config after override: {hf_config}")
+        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
+
+        if use_mbridge:
+            from verl.models.mcore.mbridge import AutoBridge
+
+            bridge = AutoBridge.from_config(hf_config)
+            bridge.set_extra_args(**override_transformer_config)
+            tf_config = bridge.config
+            self.bridge = bridge
+        else:
+            self.bridge = None
+
+        print(f"TF config: {tf_config}")
+        self.hf_config = hf_config
+        self.tf_config = tf_config
+
     def _build_model_optimizer(
         self, model_path, optim_config, override_model_config, override_transformer_config
     ):
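
Note (not part of the patch): the rollout-side change in trinity/common/models/vllm_model.py only forwards rope fields that were explicitly configured, so vLLM otherwise keeps the checkpoint's own RoPE settings. A minimal, self-contained sketch of that filtering pattern follows; the RopeConfig stand-in dataclass and the sample values are illustrative assumptions, not Trinity's real classes.

from dataclasses import dataclass
from typing import Optional


@dataclass
class RopeConfig:  # hypothetical stand-in for the rope fields of InferenceModelConfig
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


def build_rope_kwargs(config: RopeConfig) -> dict:
    # Keep only the fields the user actually set, mirroring the dict
    # comprehension added in vllm_model.py; the result is splatted into
    # vllm.AsyncEngineArgs(..., **rope_kwargs) in the patch above.
    return {
        key: getattr(config, key)
        for key in ("rope_scaling", "rope_theta")
        if getattr(config, key) is not None
    }


cfg = RopeConfig(
    rope_scaling={"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 40960}
)
print(build_rope_kwargs(cfg))  # only rope_scaling appears; rope_theta stays unset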
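
On the trainer side, fsdp_workers.py and megatron_workers.py apply the same patch to the Hugging Face model config before building the model: rope_scaling is converted with OmegaConf.to_container (the worker config is an OmegaConf DictConfig) and rope_theta is copied over directly. With YaRN, the usable context is roughly factor * original_max_position_embeddings, which is why the new test pairs factor=2.0 and original_max_position_embeddings=40960 with max_model_len=81920. Below is a hedged sketch of the patch, using types.SimpleNamespace in place of a real transformers PretrainedConfig and assumed config values.

from types import SimpleNamespace

from omegaconf import OmegaConf

# Trainer config fragment as it might arrive from YAML (assumed values).
model_cfg = OmegaConf.create(
    {
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": 2.0,
            "original_max_position_embeddings": 40960,
        },
        "rope_theta": None,
    }
)

# Stand-in for the loaded HF config; the real code patches AutoConfig.from_pretrained(...).
hf_config = SimpleNamespace(rope_scaling=None, rope_theta=None)

if model_cfg.rope_scaling is not None:
    # to_container turns the DictConfig node into a plain dict, the type
    # transformers expects for rope_scaling.
    hf_config.rope_scaling = OmegaConf.to_container(model_cfg.rope_scaling)
if model_cfg.rope_theta is not None:
    hf_config.rope_theta = model_cfg.rope_theta

# Rough YaRN sanity check: 2.0 * 40960 == 81920, the test's max_model_len.
scaling = hf_config.rope_scaling
assert scaling["factor"] * scaling["original_max_position_embeddings"] == 81920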