trinity/common/config.py: 1 addition & 0 deletions

@@ -477,6 +477,7 @@ class ModelConfig:
 class InferenceModelConfig:
     # ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
     model_path: Optional[str] = None
+    name: Optional[str] = None
 
     engine_type: str = "vllm"
     engine_num: int = 1
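
For illustration, a hedged sketch of how the new field might be set when declaring an auxiliary model. The values are made up for this example, and model_path is omitted because, per the comment above, it is filled in automatically and should not be set by hand:

from trinity.common.config import InferenceModelConfig

# Illustrative values only; `name` gives the model a stable key that
# downstream code (see the workflow change below) can look it up by.
aux_model_config = InferenceModelConfig(
    name="judge",
    engine_type="vllm",
    engine_num=1,
)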
trinity/common/models/model.py: 14 additions & 0 deletions

@@ -66,6 +66,10 @@ def get_model_path(self) -> Optional[str]:
         """Get the model path"""
         return None
 
+    def get_model_name(self) -> Optional[str]:
+        """Get the name of the model."""
+        return None
+
 
 def _history_recorder(func):
     """Decorator to record history of the model calls."""

@@ -279,6 +283,16 @@ async def model_path_async(self) -> str:
         """Get the model path."""
         return await self.model.get_model_path.remote()
 
+    @property
+    def model_name(self) -> Optional[str]:
+        """Get the name of the model."""
+        return ray.get(self.model.get_model_name.remote())
+
+    @property
+    async def model_name_async(self) -> Optional[str]:
+        """Get the name of the model."""
+        return await self.model.get_model_name.remote()
+
     def get_lora_request(self) -> Any:
         if self.enable_lora:
             return ray.get(self.model.get_lora_request.remote())
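
Both accessors proxy the same remote actor method; the sync property blocks on ray.get, while the async property returns a coroutine that must be awaited. A minimal usage sketch, assuming a wrapper instance obtained elsewhere (the diff does not show how):

from typing import Optional

def read_name_blocking(wrapper) -> Optional[str]:
    # Resolves via ray.get; suitable outside an event loop.
    return wrapper.model_name

async def read_name_async(wrapper) -> Optional[str]:
    # Accessing the property yields a coroutine, so it is awaited;
    # suitable inside async workflow code.
    return await wrapper.model_name_async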
trinity/common/models/vllm_model.py: 4 additions & 0 deletions

@@ -39,6 +39,7 @@ def __init__(
         import vllm
         from vllm.sampling_params import RequestOutputKind
 
+        self.name = config.name
         self.logger = get_logger(__name__)
         self.vllm_version = get_vllm_version()
         self.config = config

@@ -718,6 +719,9 @@ def get_model_version(self) -> int:
     def get_model_path(self) -> str:
         return self.config.model_path  # type: ignore [return-value]
 
+    def get_model_name(self) -> Optional[str]:
+        return self.name  # type: ignore [return-value]
+
     def get_lora_request(self, lora_path: Optional[str] = None) -> Any:
        from vllm.lora.request import LoRARequest
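
Since `name` defaults to None, get_model_name() can return None and callers need a fallback. A hypothetical helper sketching that pattern (not part of the PR):

def display_name(model) -> str:
    # Hypothetical: prefer the configured name, fall back to the path,
    # mirroring the `or` fallback used in the workflow change below.
    return model.get_model_name() or model.get_model_path()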
trinity/common/workflows/agentscope_workflow.py: 12 additions & 7 deletions

@@ -118,13 +118,18 @@ def __init__(
                 "top_logprobs": self.task.rollout_args.logprobs,
             },
         )
-        self.auxiliary_chat_models = [
-            TrinityChatModel(
-                openai_async_client=aux_model,
-                # TODO: customize generate_kwargs for auxiliary models if needed
-            )
-            for aux_model in (self.auxiliary_models or [])
-        ]
+
+        # TODO: customize generate_kwargs for auxiliary models if needed
+        if self.auxiliary_model_wrappers is not None and self.auxiliary_models is not None:
+            self.auxiliary_chat_models = {
+                aux_model_wrapper.model_name
+                or f"auxiliary_model_{i}": TrinityChatModel(openai_async_client=aux_model)
+                for i, (aux_model_wrapper, aux_model) in enumerate(
+                    zip(self.auxiliary_model_wrappers, self.auxiliary_models)
+                )
+            }
+        else:
+            self.auxiliary_chat_models = {}
 
     def construct_experiences(
         self,
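
With auxiliary chat models keyed by name rather than held in a list, a workflow can select a specific model without depending on list order. A hypothetical lookup (the key "judge" is illustrative, not from the PR):

# Named models use their configured `name`; unnamed ones fall back to
# positional keys like "auxiliary_model_0".
judge = self.auxiliary_chat_models.get("judge")
if judge is None:
    judge = self.auxiliary_chat_models.get("auxiliary_model_0")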