1 change: 1 addition & 0 deletions trinity/algorithm/key_mapper.py
@@ -26,4 +26,5 @@ def from_trinity(self, key: str) -> str:
"advantages": "advantages",
}
),
"tinker": KeyMapper({}),
}
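
For context, the only change to this file registers an empty mapper for the new backend. A reasonable reading is that an empty mapping means keys pass through unchanged; the toy sketch below illustrates that assumption and is not the repository's actual KeyMapper implementation.

# Toy illustration of the assumed pass-through behavior of KeyMapper({}).
# Only the from_trinity signature is taken from the diff; the rest is assumed.
class PassThroughKeyMapper:
    def __init__(self, mapping: dict):
        self.mapping = mapping

    def from_trinity(self, key: str) -> str:
        # With an empty mapping, every key falls back to itself.
        return self.mapping.get(key, key)

mapper = PassThroughKeyMapper({})
assert mapper.from_trinity("advantages") == "advantages"
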
231 changes: 152 additions & 79 deletions trinity/common/config.py
@@ -428,6 +428,17 @@ class DataProcessorConfig:
)


@dataclass
class TinkerConfig:
enable: bool = False  # when True, rollout and training go through the tinker backend
base_model: Optional[str] = None  # falls back to model.model_path when unset
rank: int = 32  # LoRA rank
seed: Optional[int] = None
train_mlp: bool = True
train_attn: bool = True
train_unembed: bool = True


@dataclass
class ModelConfig:
# source model path
@@ -472,6 +483,9 @@ class ModelConfig:
rope_scaling: Optional[dict] = None
rope_theta: Optional[float] = None

# tinker config
tinker: TinkerConfig = field(default_factory=TinkerConfig)


@dataclass
class InferenceModelConfig:
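
For reference, the sketch below shows how the new `tinker` block might be populated programmatically. Only the fields defined in `TinkerConfig` above are used; the concrete values (and the model name) are illustrative assumptions, not recommended settings.

# Sketch: filling in the new TinkerConfig, using only fields from the dataclass above.
tinker_cfg = TinkerConfig(
    enable=True,
    base_model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model name
    rank=32,  # LoRA rank
    seed=42,
    train_mlp=True,
    train_attn=True,
    train_unembed=False,  # e.g. keep the unembedding frozen
)
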
@@ -1146,6 +1160,9 @@ def _check_model(self) -> None:
if not model.critic_model_path:
model.critic_model_path = model.model_path

if model.tinker.enable:
self._check_tinker()

# check template
if model.chat_template_path is not None and model.custom_chat_template is None:
try:
@@ -1157,7 +1174,48 @@ def _check_model(self) -> None:
)

# check max_model_len, max_prompt_tokens, max_response_tokens
self._check_model_len()

def _check_tinker(self) -> None:
model = self.model
from trinity.algorithm import ALGORITHM_TYPE

algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type)
if algorithm.use_critic:
raise ValueError("Critic model is not supported when using tinker!")

set_if_none(model.tinker, "base_model", model.model_path)

import tinker

service_client = tinker.ServiceClient()
supported_models = {
item.model_name for item in service_client.get_server_capabilities().supported_models
}
if model.tinker.base_model not in supported_models:
raise ValueError(
f"{model.tinker.base_model} is not supported by tinker! Supported models: {supported_models}"
)
if model.tinker.base_model != model.model_path:
logger.warning(
f"The local tokenizer will use {model.model_path}, while tinker will use {model.tinker.base_model}"
)

if self.explorer.rollout_model.engine_type != "tinker":
self.explorer.rollout_model.engine_type = "tinker"
logger.warning("Rollout model engine type is set to `tinker`.")

if self.trainer.trainer_type != "tinker":
self.trainer.trainer_type = "tinker"
logger.warning("Trainer type is set to `tinker`.")

if self.synchronizer.sync_method == SyncMethod.NCCL:
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"Tinker do not support NCCL, `synchronizer.sync_method` is set to `checkpoint`."
)

def _check_model_len(self) -> None:
model = self.model
# if all three are set, check if they are valid
if (
model.max_model_len is not None
@@ -1222,6 +1280,84 @@ def _check_model(self) -> None:
"`enable_prompt_truncation` is set to False; please make sure the prompt is not too long and `max_model_len` is large enough, otherwise prompt length + response length may exceed `max_model_len`!"
)

def _check_explorer(self) -> None:
rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"]
length_args = [
"max_model_len",
"max_prompt_tokens",
"max_response_tokens",
"min_response_tokens",
"enable_prompt_truncation",
]
rope_args = ["rope_scaling", "rope_theta"]
model_args = rollout_args + length_args + rope_args
model_path = (
self.model.tinker.base_model
if self.explorer.rollout_model.engine_type == "tinker"
else self.model.model_path
)
set_if_none(self.explorer.rollout_model, "model_path", model_path)
for args in model_args:
set_if_none(self.explorer.rollout_model, args, getattr(self.model, args))
if (
self.explorer.rollout_model.chat_template is None
and self.model.custom_chat_template is not None
):
self.explorer.rollout_model.chat_template = self.model.custom_chat_template
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
for args in model_args:
set_if_none(aux_model, args, getattr(self.model, args))

if self.explorer.over_rollout.ratio > 0.0:
if not (0.0 <= self.explorer.over_rollout.ratio < 1.0):
raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")
if self.synchronizer.sync_style == SyncStyle.FIXED:
raise ValueError(
"over_rollout_ratio is not compatible with fixed sync_style, please set "
"`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`."
)

# for lora configs
if not self.model.tinker.enable and self.model.lora_configs is not None:
self.explorer.rollout_model.enable_lora = True
if len(self.model.lora_configs) > 1:
raise ValueError("Only one lora adapter is supported for now.")
if self.model.lora_configs[0].path is None:
logger.info("Creating dummy lora, since no lora_path is provided.")
lora_path = create_dummy_lora(
model_path=self.model.model_path,
checkpoint_job_dir=self.checkpoint_job_dir,
lora_rank=self.model.lora_configs[0].lora_rank,
lora_alpha=self.model.lora_configs[0].lora_alpha,
target_modules=self.model.lora_configs[0].target_modules,
)
self.model.lora_configs[0].path = lora_path
self.explorer.rollout_model.lora_modules = [
{
"lora_int_id": i + 1,
"lora_name": cfg.name,
"lora_path": cfg.path,
"base_model_name": cfg.base_model_name,
}
for i, cfg in enumerate(self.model.lora_configs)
]
self.explorer.rollout_model.lora_kwargs = {
"max_loras": len(self.model.lora_configs),
"max_lora_rank": max(
(
model_config.lora_rank
for model_config in self.model.lora_configs
if model_config.lora_rank > 0
),
default=0,
),
"default_lora_path": os.path.join(
self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter"
), # will be popped later
}

def __iter__(self):
"""Iterate over configs with each stage applied in order.

@@ -1288,91 +1424,25 @@ def check_and_update(self) -> Config: # noqa: C901

# check explorer
if self.explorer is not None:
rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"]
length_args = [
"max_model_len",
"max_prompt_tokens",
"max_response_tokens",
"min_response_tokens",
"enable_prompt_truncation",
]
rope_args = ["rope_scaling", "rope_theta"]
model_args = rollout_args + length_args + rope_args
for args in ["model_path"] + model_args:
set_if_none(self.explorer.rollout_model, args, getattr(self.model, args))
if (
self.explorer.rollout_model.chat_template is None
and self.model.custom_chat_template is not None
):
self.explorer.rollout_model.chat_template = self.model.custom_chat_template
for aux_model in self.explorer.auxiliary_models:
if not aux_model.model_path:
raise ValueError("auxiliary model's model_path is required.")
for args in model_args:
set_if_none(aux_model, args, getattr(self.model, args))

if self.explorer.over_rollout.ratio > 0.0:
if not (0.0 <= self.explorer.over_rollout.ratio < 1.0):
raise ValueError("over_rollout_ratio should be in [0.0, 1.0)")
if self.synchronizer.sync_style == SyncStyle.FIXED:
raise ValueError(
"over_rollout_ratio is not compatible with fixed sync_style, please set "
"`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`."
)

# for lora configs
if self.model.lora_configs is not None:
self.explorer.rollout_model.enable_lora = True
if len(self.model.lora_configs) > 1:
raise ValueError("Only one lora adapter is supported for now.")
if self.model.lora_configs[0].path is None:
logger.info("Creating dummy lora, since no lora_path is provided.")
lora_path = create_dummy_lora(
model_path=self.model.model_path,
checkpoint_job_dir=self.checkpoint_job_dir,
lora_rank=self.model.lora_configs[0].lora_rank,
lora_alpha=self.model.lora_configs[0].lora_alpha,
target_modules=self.model.lora_configs[0].target_modules,
)
self.model.lora_configs[0].path = lora_path
self.explorer.rollout_model.lora_modules = [
{
"lora_int_id": i + 1,
"lora_name": cfg.name,
"lora_path": cfg.path,
"base_model_name": cfg.base_model_name,
}
for i, cfg in enumerate(self.model.lora_configs)
]
self.explorer.rollout_model.lora_kwargs = {
"max_loras": len(self.model.lora_configs),
"max_lora_rank": max(
(
model_config.lora_rank
for model_config in self.model.lora_configs
if model_config.lora_rank > 0
),
default=0,
),
"default_lora_path": os.path.join(
self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter"
), # will be popped later
}
self._check_explorer()

# check synchronizer
self.synchronizer.ray_namespace = self.ray_namespace
self.synchronizer.explorer_world_size = (
self.explorer.rollout_model.engine_num
* self.explorer.rollout_model.tensor_parallel_size
)
if (
self.mode in ["train", "explore", "bench", "serve"]
and self.synchronizer.sync_method == SyncMethod.NCCL
):
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if self.synchronizer.sync_method == SyncMethod.NCCL:
if self.mode in ["train", "explore", "bench", "serve"]:
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`."
)
if self.model.lora_configs is not None:
self.synchronizer.sync_method = SyncMethod.CHECKPOINT
logger.warning(
"LoRA is not supported with NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`."
)

self._check_interval()

@@ -1424,9 +1494,12 @@ def check_and_update(self) -> Config: # noqa: C901
f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, "
"must be one of 'last', 'always', or 'never'."
)
elif self.trainer.trainer_type == "tinker":
self.trainer.trainer_config = None
else:
raise ValueError(f"Invalid trainer type: {self.trainer_type}")
self.trainer.trainer_config.synchronize_config(self)
if self.trainer.trainer_config:
self.trainer.trainer_config.synchronize_config(self)

# check service
if self.service.data_juicer is not None:
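
As a side note, the supported-model check in `_check_tinker` can be reproduced standalone before launching a job. The snippet below only mirrors the tinker calls that already appear in the diff above; it assumes the `tinker` package is installed and credentials are configured, and the model name is a placeholder.

# Standalone version of the capability check performed in _check_tinker.
import tinker

service_client = tinker.ServiceClient()
supported_models = {
    item.model_name
    for item in service_client.get_server_capabilities().supported_models
}

base_model = "Qwen/Qwen2.5-7B-Instruct"  # hypothetical model name
if base_model not in supported_models:
    raise ValueError(f"{base_model} is not supported by tinker! Supported models: {supported_models}")
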
28 changes: 28 additions & 0 deletions trinity/common/models/__init__.py
@@ -45,6 +45,7 @@ def create_inference_models(
from ray.util.placement_group import placement_group, placement_group_table
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from trinity.common.models.tinker_model import TinkerModel
from trinity.common.models.vllm_model import vLLMRolloutModel

logger = get_logger(__name__)
@@ -54,6 +55,33 @@ def create_inference_models(
rollout_engines = []
if config.explorer.rollout_model.engine_type.startswith("vllm"):
engine_cls = vLLMRolloutModel
elif config.explorer.rollout_model.engine_type == "tinker":
engine_cls = TinkerModel
namespace = ray.get_runtime_context().namespace
rollout_engines = [
ray.remote(engine_cls)
.options(
name=f"{config.explorer.name}_rollout_model_{i}",
namespace=namespace,
)
.remote(
config=config.explorer.rollout_model,
)
for i in range(engine_num)
]
auxiliary_engines = [
ray.remote(engine_cls)
.options(
name=f"{config.explorer.name}_auxiliary_model_{i}_{j}",
namespace=namespace,
)
.remote(
config=config.explorer.auxiliary_models[i],
)
for i, model_config in enumerate(config.explorer.auxiliary_models)
for j in range(model_config.engine_num)
]
return rollout_engines, auxiliary_engines
else:
raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}")

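
The new `tinker` branch mirrors the existing vLLM branch: pick an engine class from `engine_type`, then spawn one named Ray actor per engine in the current namespace. A simplified sketch of that dispatch pattern is shown below; the helper name and config layout are assumptions for illustration, while the imports and Ray calls are taken from the diff.

# Simplified dispatch-and-spawn pattern used in create_inference_models.
import ray

from trinity.common.models.tinker_model import TinkerModel
from trinity.common.models.vllm_model import vLLMRolloutModel


def spawn_rollout_engines(engine_type, engine_num, rollout_config, explorer_name):
    if engine_type.startswith("vllm"):
        engine_cls = vLLMRolloutModel
    elif engine_type == "tinker":
        engine_cls = TinkerModel
    else:
        raise ValueError(f"Unknown engine type: {engine_type}")

    namespace = ray.get_runtime_context().namespace
    return [
        ray.remote(engine_cls)
        .options(name=f"{explorer_name}_rollout_model_{i}", namespace=namespace)
        .remote(config=rollout_config)
        for i in range(engine_num)
    ]
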
12 changes: 9 additions & 3 deletions trinity/common/models/model.py
@@ -46,6 +46,10 @@ async def prepare(self) -> None:
"""Prepare the model before inference."""
pass

@abstractmethod
async def sync_model(self, model_version: int) -> int:
"""Sync the model with the latest model_version."""

@abstractmethod
def get_model_version(self) -> int:
"""Get the checkpoint version."""
@@ -105,7 +109,9 @@ def __init__(
enable_history (bool): Whether to enable history recording. Default to False.
enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models.
"""
assert engine_type.startswith("vllm"), "Only vLLM model is supported for now."
assert (
engine_type.startswith("vllm") or engine_type == "tinker"
), "Only vLLM and tinker model is supported for now."
self.model = model
self.api_address: str = None
self.openai_client: openai.OpenAI = None
@@ -205,13 +211,13 @@ async def generate_mm_async(
def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of messages."""
lora_request = self.get_lora_request()
return ray.get(self.model.chat.remote(messages, lora_request, **kwargs))
return ray.get(self.model.chat.remote(messages, lora_request=lora_request, **kwargs))

@_history_recorder
async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of messages in async."""
lora_request = await self.get_lora_request_async()
return await self.model.chat.remote(messages, lora_request, **kwargs)
return await self.model.chat.remote(messages, lora_request=lora_request, **kwargs)

@_history_recorder
def chat_mm(
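
Finally, `sync_model` is now part of the abstract rollout-model interface, so every backend has to implement it. A minimal sketch of the contract follows, with the real weight-loading step left as a placeholder comment; only the signature and docstring come from the diff.

# Sketch of the sync_model contract added to the abstract base class.
class DummyRolloutModel:
    def __init__(self) -> None:
        self.model_version = 0

    async def sync_model(self, model_version: int) -> int:
        # A real engine would reload or hot-swap weights for `model_version` here.
        self.model_version = model_version
        return self.model_version

    def get_model_version(self) -> int:
        return self.model_version
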