Skip to content

Commit bb0875d

Browse files
Add model_name for auxiliary models (#461)
Co-authored-by: chenyushuo <[email protected]>
1 parent 7601bb9 commit bb0875d

File tree

5 files changed

+32
-8
lines changed

5 files changed

+32
-8
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -15,7 +15,6 @@ repos:
1515
rev: 23.7.0
1616
hooks:
1717
- id: black
18-
language_version: python3.12
1918
args: [--line-length=100]
2019

2120
- repo: https://github.com/pycqa/isort

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -490,6 +490,7 @@ class ModelConfig:
490490
class InferenceModelConfig:
491491
# ! DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path
492492
model_path: Optional[str] = None
493+
name: Optional[str] = None
493494

494495
engine_type: str = "vllm"
495496
engine_num: int = 1

trinity/common/models/model.py

Lines changed: 16 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -70,6 +70,10 @@ def get_model_path(self) -> Optional[str]:
7070
"""Get the model path"""
7171
return None
7272

73+
def get_model_name(self) -> Optional[str]:
74+
"""Get the name of the model."""
75+
return None
76+
7377

7478
def _history_recorder(func):
7579
"""Decorator to record history of the model calls."""
@@ -113,6 +117,7 @@ def __init__(
113117
engine_type.startswith("vllm") or engine_type == "tinker"
114118
), "Only vLLM and tinker model is supported for now."
115119
self.model = model
120+
self._model_name = None
116121
self.api_address: str = None
117122
self.openai_client: openai.OpenAI = None
118123
self.openai_async_client: openai.AsyncOpenAI = None
@@ -128,6 +133,7 @@ def __init__(
128133

129134
async def prepare(self) -> None:
130135
"""Prepare the model wrapper."""
136+
self._model_name = await self.model.get_model_name.remote()
131137
self.api_address = await self.model.get_api_server_url.remote()
132138
if self.api_address is None:
133139
self.logger.info("API server is not enabled for inference model.")
@@ -285,6 +291,16 @@ async def model_path_async(self) -> str:
285291
"""Get the model path."""
286292
return await self.model.get_model_path.remote()
287293

294+
@property
295+
def model_name(self) -> Optional[str]:
296+
"""Get the name of the model."""
297+
return self._model_name
298+
299+
@property
300+
async def model_name_async(self) -> Optional[str]:
301+
"""Get the name of the model."""
302+
return self._model_name
303+
288304
def get_lora_request(self) -> Any:
289305
if self.enable_lora:
290306
return ray.get(self.model.get_lora_request.remote())

trinity/common/models/vllm_model.py

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -718,6 +718,9 @@ def get_model_version(self) -> int:
718718
def get_model_path(self) -> str:
719719
return self.config.model_path # type: ignore [return-value]
720720

721+
def get_model_name(self) -> Optional[str]:
722+
return self.config.name # type: ignore [return-value]
723+
721724
def get_lora_request(self, lora_path: Optional[str] = None) -> Any:
722725
from vllm.lora.request import LoRARequest
723726

trinity/common/workflows/agentscope_workflow.py

Lines changed: 12 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -118,13 +118,18 @@ def __init__(
118118
"top_logprobs": self.task.rollout_args.logprobs,
119119
},
120120
)
121-
self.auxiliary_chat_models = [
122-
TrinityChatModel(
123-
openai_async_client=aux_model,
124-
# TODO: customize generate_kwargs for auxiliary models if needed
125-
)
126-
for aux_model in (self.auxiliary_models or [])
127-
]
121+
122+
# TODO: customize generate_kwargs for auxiliary models if needed
123+
if self.auxiliary_model_wrappers is not None and self.auxiliary_models is not None:
124+
self.auxiliary_chat_models = {
125+
aux_model_wrapper.model_name
126+
or f"auxiliary_model_{i}": TrinityChatModel(openai_async_client=aux_model)
127+
for i, (aux_model_wrapper, aux_model) in enumerate(
128+
zip(self.auxiliary_model_wrappers, self.auxiliary_models)
129+
)
130+
}
131+
else:
132+
self.auxiliary_chat_models = {}
128133

129134
def construct_experiences(
130135
self,

0 commit comments

Comments (0)