57 changes: 57 additions & 0 deletions tests/common/vllm_test.py
@@ -1,3 +1,4 @@
import os
import unittest

import ray
@@ -935,3 +936,59 @@ async def test_api_tool_calls(self):
        print_debug(
            "\n" + "=" * 28 + f" test_api_tool_calls PASSED in {total_time:.2f}s " + "=" * 28 + "\n"
        )


class TestSuperLongGeneration(RayUnittestBaseAysnc):
    def setUp(self):
        self.config = get_template_config()
        self.config.mode = "explore"
        self.config.model.model_path = get_model_path()
        self.config.model.max_model_len = 81920
        self.config.model.max_prompt_tokens = 61440
        self.config.model.max_response_tokens = 20480
        self.config.model.rope_scaling = {
            "rope_type": "yarn",
            "factor": 2.0,
            "original_max_position_embeddings": 40960,
        }
        self.config.explorer.rollout_model.engine_type = "vllm"
        self.config.explorer.rollout_model.engine_num = 1
        self.config.explorer.rollout_model.tensor_parallel_size = 1
        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE

        self.config.check_and_update()
        self.engines, self.auxiliary_engines = create_inference_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True)

    async def test_generate(self):
        base_dir = os.path.dirname(__file__)
        target_dir = os.path.join(base_dir, "..", "..", "trinity", "trainer", "verl")
        with open(os.path.join(target_dir, "fsdp_workers.py")) as f:
            fsdp_code = f.read()
        with open(os.path.join(target_dir, "megatron_workers.py")) as f:
            megatron_code = f.read()
        target_dir = os.path.join(base_dir, "..", "..", "trinity", "common")
        with open(os.path.join(target_dir, "config.py")) as f:
            config_code = f.read()
        target_dir = os.path.join(base_dir, "..", "..", "trinity", "manager")
        with open(os.path.join(target_dir, "config_manager.py")) as f:
            config_manager_code = f.read()

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": """# Please add comments and documentation for these following code, """
                """make sure the code is well-structured and easy to read, """
                """and the complete code must be shown, do not omit any parts.\n"""
                f"""## fsdp_workers.py\n{fsdp_code}\n"""
                f"""## megatron_workers.py\n{megatron_code}\n"""
                f"""## config.py\n{config_code}\n"""
                f"""## config_manager.py\n{config_manager_code}\n""",
            },
        ]
        response = self.model_wrapper.chat(messages, n=1, temperature=0.7, logprobs=True)[0]
        self.assertGreater(
            response.prompt_length, 40960
        )  # If not long enough, please add more files to prompt
        self.assertGreater(response.logprobs.shape[0], 1000)
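A quick aside on the numbers used in `TestSuperLongGeneration.setUp`: under the Hugging Face `rope_scaling` convention assumed here, YaRN stretches the usable context window by `factor`, and the prompt/response budgets are chosen to fill that scaled window exactly. A minimal, hypothetical sanity check of that arithmetic:

```python
# Hypothetical arithmetic check of the YaRN settings above.
original_max_position_embeddings = 40960  # pre-scaling context window
factor = 2.0                              # YaRN scaling factor
max_model_len = 81920
max_prompt_tokens = 61440
max_response_tokens = 20480

# YaRN extends the usable context by `factor`.
assert int(original_max_position_embeddings * factor) == max_model_len
# The prompt and response budgets together fill the scaled window.
assert max_prompt_tokens + max_response_tokens == max_model_len
```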
6 changes: 6 additions & 0 deletions tests/trainer/trainer_test.py
@@ -73,6 +73,12 @@ class TestTrainerCountdown(BaseTrainerCase):
    def test_trainer(self):
        """Test the both and bench mode."""
        # test both mode
        self.config.model.rope_scaling = {
            "rope_type": "yarn",
            "factor": 2.0,
            "original_max_position_embeddings": 16384,
        }
        self.config.model.rope_theta = 10000
        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
        self.config.buffer.explorer_input.taskset.task_selector = TaskSelectorConfig(
            selector_type="shuffle", seed=42
14 changes: 12 additions & 2 deletions trinity/common/config.py
@@ -437,6 +437,10 @@ class ModelConfig:
    fully_sharded_loras: bool = False
    max_cpu_loras: Optional[int] = None

    # rope config
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


@dataclass
class InferenceModelConfig:
@@ -498,6 +502,10 @@ class InferenceModelConfig:
    lora_modules: Optional[List[Dict]] = None
    lora_kwargs: Optional[dict] = field(default_factory=dict)

    # ! DO NOT SET, rope config
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


@dataclass
class AlgorithmConfig:
@@ -1190,12 +1198,14 @@ def check_and_update(self) -> Config: # noqa: C901
            "max_response_tokens",
            "min_response_tokens",
        ]
        rope_args = ["rope_scaling", "rope_theta"]
        model_args = rollout_args + length_args + rope_args
        for args in ["model_path"] + model_args:
            setattr(self.explorer.rollout_model, args, getattr(self.model, args))
        for aux_model in self.explorer.auxiliary_models:
            if not aux_model.model_path:
                raise ValueError("auxiliary model's model_path is required.")
            for args in model_args:
                set_if_none(aux_model, args, getattr(self.model, args))

        # for lora configs
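The `check_and_update` change propagates the new rope fields the same way as the existing rollout and length arguments: the rollout model always mirrors the top-level `model` settings, while auxiliary models only inherit values they have not set themselves. A simplified, self-contained sketch of that pattern (the `_ModelCfg` dataclass and the values below are hypothetical stand-ins, not the real config classes):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _ModelCfg:  # hypothetical stand-in for the real config dataclasses
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


def set_if_none(obj, name, value):
    # Fill a field only when it was left unset.
    if getattr(obj, name) is None:
        setattr(obj, name, value)


model = _ModelCfg(rope_scaling={"rope_type": "yarn", "factor": 2.0}, rope_theta=10000.0)
rollout = _ModelCfg()
aux = _ModelCfg(rope_theta=5000.0)  # auxiliary model with its own rope_theta

for name in ("rope_scaling", "rope_theta"):
    setattr(rollout, name, getattr(model, name))  # rollout model: unconditional copy
    set_if_none(aux, name, getattr(model, name))  # auxiliary model: copy only if unset

assert rollout.rope_theta == 10000.0
assert aux.rope_theta == 5000.0  # explicit override is preserved
```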
6 changes: 6 additions & 0 deletions trinity/common/models/vllm_model.py
@@ -77,6 +77,11 @@ def __init__(
        max_model_len = config.max_model_len
        self.enable_lora = config.enable_lora
        self.default_lora_path = config.lora_kwargs.pop("default_lora_path", None)
        rope_kwargs = {
            key: getattr(config, key)
            for key in ["rope_scaling", "rope_theta"]
            if getattr(config, key) is not None
        }
        engine_args = vllm.AsyncEngineArgs(
            model=config.model_path,
            enforce_eager=config.enforce_eager,
@@ -101,6 +106,7 @@ def __init__(
            disable_log_stats=True,
            enable_lora=config.enable_lora,
            logprobs_mode="processed_logprobs",
            **rope_kwargs,
            **config.lora_kwargs,
        )
        if get_vllm_version() > parse_version("0.10.0"):
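The `rope_kwargs` comprehension forwards only fields that were explicitly set, so an unset value never overrides whatever the model's own Hugging Face config already specifies. A minimal, hypothetical illustration of the filtering:

```python
from types import SimpleNamespace

# Hypothetical config: rope_scaling is set, rope_theta is left unset.
config = SimpleNamespace(rope_scaling={"rope_type": "yarn", "factor": 2.0}, rope_theta=None)

rope_kwargs = {
    key: getattr(config, key)
    for key in ["rope_scaling", "rope_theta"]
    if getattr(config, key) is not None
}

# Only rope_scaling is passed along; rope_theta falls back to the model's own default.
assert rope_kwargs == {"rope_scaling": {"rope_type": "yarn", "factor": 2.0}}
```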
6 changes: 6 additions & 0 deletions trinity/common/verl_config.py
@@ -40,6 +40,10 @@ class ActorModel:
    lora_alpha: int = 32
    target_modules: Optional[str] = "all-linear"

    # rope configs
    rope_scaling: Optional[dict] = None
    rope_theta: Optional[float] = None


@dataclass
class Optim:
@@ -412,6 +416,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
        # Actor / Rollout Config
        self.actor_rollout_ref.model.path = config.model.model_path
        self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
        self.actor_rollout_ref.model.rope_scaling = config.model.rope_scaling
        self.actor_rollout_ref.model.rope_theta = config.model.rope_theta
        self.actor_rollout_ref.actor.optim.total_training_steps = self.trainer.total_training_steps
        self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size
        self.actor_rollout_ref.rollout.temperature = (
6 changes: 6 additions & 0 deletions trinity/trainer/verl/fsdp_workers.py
@@ -257,6 +257,12 @@ def _build_model_optimizer( # noqa: C901
            local_path, trust_remote_code=trust_remote_code, attn_implementation="flash_attention_2"
        )

        # patch for rope
        if self.config.model.rope_scaling is not None:
            actor_model_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
        if self.config.model.rope_theta is not None:
            actor_model_config.rope_theta = self.config.model.rope_theta

        # patch for kimi-vl
        if getattr(actor_model_config, "model_type", None) == "kimi_vl":
            actor_model_config.text_config.topk_method = "greedy"
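Because the patch calls `OmegaConf.to_container`, `rope_scaling` is evidently an OmegaConf `DictConfig` by the time it reaches the worker; the conversion yields the plain dict that `transformers` expects on the model config. A minimal sketch of just that conversion, assuming only `omegaconf` is installed:

```python
from omegaconf import OmegaConf

rope_scaling_cfg = OmegaConf.create(
    {"rope_type": "yarn", "factor": 2.0, "original_max_position_embeddings": 40960}
)

# Convert the DictConfig into a regular dict before attaching it to the HF model config.
rope_scaling = OmegaConf.to_container(rope_scaling_cfg)

assert isinstance(rope_scaling, dict) and not OmegaConf.is_config(rope_scaling)
```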
76 changes: 76 additions & 0 deletions trinity/trainer/verl/megatron_workers.py
@@ -151,6 +151,82 @@ def __init__(self, config: DictConfig, role: str, **kwargs):
        )
        self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)

    def _init_hf_config_and_tf_config(
        self,
        model_path,
        tokenizer_or_path,
        dtype,
        override_model_config,
        override_transformer_config,
        trust_remote_code=False,
        use_mbridge=False,
    ):
        from transformers import AutoConfig
        from verl.models.mcore import hf_to_mcore_config
        from verl.utils import hf_processor, hf_tokenizer
        from verl.utils.fs import copy_to_local
        from verl.utils.model import update_model_config

        # Step 1: initialize the tokenizer
        self.local_path = copy_to_local(model_path)
        if tokenizer_or_path is None:
            self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
            self.processor = hf_processor(self.local_path, trust_remote_code=trust_remote_code)
        elif isinstance(tokenizer_or_path, str):
            self.tokenizer = hf_tokenizer(
                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
            )
            self.processor = hf_processor(
                copy_to_local(tokenizer_or_path), trust_remote_code=trust_remote_code
            )
        else:
            self.tokenizer = tokenizer_or_path
            self.processor = tokenizer_or_path

        if self.config.model.get("custom_chat_template", None) is not None:
            if self.processor is not None:
                self.processor.chat_template = self.config.model.custom_chat_template
            else:
                self.tokenizer.chat_template = self.config.model.custom_chat_template

        # Step 2: get the hf config
        hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)

        # Step 3: override the hf config
        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(override_model_config.get("model_config", {}))

        # patch for rope
        if self.config.model.rope_scaling is not None:
            hf_config.rope_scaling = OmegaConf.to_container(self.config.model.rope_scaling)
        if self.config.model.rope_theta is not None:
            hf_config.rope_theta = self.config.model.rope_theta

        self.share_embeddings_and_output_weights = getattr(hf_config, "tie_word_embeddings", False)
        update_model_config(hf_config, override_config_kwargs=override_config_kwargs)
        self.architectures = getattr(hf_config, "architectures", None)
        if self.rank == 0:
            print(f"Model config after override: {hf_config}")
        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)

        if use_mbridge:
            from verl.models.mcore.mbridge import AutoBridge

            bridge = AutoBridge.from_config(hf_config)
            bridge.set_extra_args(**override_transformer_config)
            tf_config = bridge.config
            self.bridge = bridge
        else:
            self.bridge = None

        print(f"TF config: {tf_config}")
        self.hf_config = hf_config
        self.tf_config = tf_config

    def _build_model_optimizer(
        self, model_path, optim_config, override_model_config, override_transformer_config
    ):