
Commit af35c31

add generation args to ModelConfig
1 parent 6ff5195 commit af35c31

File tree

trinity/common/config.py
trinity/common/models/model.py
trinity/common/models/vllm_model.py
trinity/common/workflows/agentscope_workflow.py

4 files changed: +73 −18 lines

trinity/common/config.py

Lines changed: 45 additions & 13 deletions
@@ -78,10 +78,10 @@ class FormatConfig:
 
 @dataclass
 class GenerationConfig:
-    temperature: float = 1.0
-    top_p: float = 1.0
-    top_k: int = -1
-    logprobs: int = 0  # vLLM return `logprobs + 1` elements
+    temperature: Optional[float] = None  # 1.0
+    top_p: Optional[float] = None  # 1.0
+    top_k: Optional[int] = None  # -1
+    logprobs: Optional[int] = None  # 0 # vLLM return `logprobs + 1` elements
     max_tokens: Optional[int] = None  # if None, use model.max_response_tokens
     # repeat each task for `n` times
     # ! DO NOT SET in `buffer.explorer_input.taskset.rollout_args`
@@ -412,6 +412,12 @@ class ModelConfig:
 
     custom_chat_template: Optional[str] = None
 
+    # rollout args
+    temperature: float = 1.0
+    top_p: float = 1.0
+    top_k: int = -1
+    logprobs: int = 0
+
     # the total number of tokens the model can handle
     max_model_len: Optional[int] = None

@@ -447,6 +453,12 @@ class InferenceModelConfig:
     dtype: str = "bfloat16"
     seed: int = 42
 
+    # rollout args, ! DO NOT SET
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    logprobs: Optional[int] = None
+
     # if not set, use `model.max_model_len`
     max_model_len: Optional[int] = None
     # if not set, use `model.max_prompt_tokens`
@@ -853,6 +865,10 @@ def _check_explorer_input(self) -> None:
         set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type)
         set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
         set_if_none(taskset, "ray_namespace", self.ray_namespace)
+        set_if_none(taskset.rollout_args, "temperature", self.model.temperature)
+        set_if_none(taskset.rollout_args, "top_p", self.model.top_p)
+        set_if_none(taskset.rollout_args, "top_k", self.model.top_k)
+        set_if_none(taskset.rollout_args, "logprobs", self.model.logprobs)
         set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
         for idx, dataset in enumerate(explorer_input.eval_tasksets):
@@ -868,6 +884,10 @@ def _check_explorer_input(self) -> None:
             set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type)
             set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type)
             set_if_none(dataset, "ray_namespace", self.ray_namespace)
+            set_if_none(dataset.rollout_args, "temperature", self.model.temperature)
+            set_if_none(dataset.rollout_args, "top_p", self.model.top_p)
+            set_if_none(dataset.rollout_args, "top_k", self.model.top_k)
+            set_if_none(dataset.rollout_args, "logprobs", self.model.logprobs)
             set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens)
 
     def _check_trainer_input(self) -> None:
@@ -1161,18 +1181,30 @@ def check_and_update(self) -> Config:  # noqa: C901
 
         # check explorer
         if self.explorer is not None:
-            self.explorer.rollout_model.model_path = self.model.model_path
-            self.explorer.rollout_model.max_model_len = self.model.max_model_len
-            self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
-            self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
-            self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
+            rollout_args = ["temperature", "top_p", "top_k", "logprobs"]
+            length_args = [
+                "max_model_len",
+                "max_prompt_tokens",
+                "max_response_tokens",
+                "min_response_tokens",
+            ]
+            for args in ["model_path"] + rollout_args + length_args:
+                setattr(self.explorer.rollout_model, args, getattr(self.model, args))
+            # self.explorer.rollout_model.model_path = self.model.model_path
+            # self.explorer.rollout_model.temperature = self.model.temperature
+            # self.explorer.rollout_model.max_model_len = self.model.max_model_len
+            # self.explorer.rollout_model.max_prompt_tokens = self.model.max_prompt_tokens
+            # self.explorer.rollout_model.max_response_tokens = self.model.max_response_tokens
+            # self.explorer.rollout_model.min_response_tokens = self.model.min_response_tokens
             for aux_model in self.explorer.auxiliary_models:
                 if not aux_model.model_path:
                     raise ValueError("auxiliary model's model_path is required.")
-                set_if_none(aux_model, "max_model_len", self.model.max_model_len)
-                set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
-                set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
-                set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)
+                for args in rollout_args + length_args:
+                    set_if_none(aux_model, args, getattr(self.model, args))
+                # set_if_none(aux_model, "max_model_len", self.model.max_model_len)
+                # set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens)
+                # set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens)
+                # set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens)
 
         # for lora configs
         if self.model.lora_configs is not None:
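
Note: the taskset and eval-dataset propagation above goes through `set_if_none`, so the model-level rollout args only fill fields the user left as `None`; explicit per-taskset `rollout_args` still take precedence. A minimal sketch of that precedence (the `set_if_none` below is a simplified stand-in for illustration, not the project's implementation):

from dataclasses import dataclass
from typing import Any, Optional


def set_if_none(obj: Any, name: str, value: Any) -> None:
    # Stand-in helper: assign the fallback only when the field is still unset.
    if getattr(obj, name) is None:
        setattr(obj, name, value)


@dataclass
class RolloutArgs:  # simplified stand-in for GenerationConfig
    temperature: Optional[float] = None
    top_k: Optional[int] = None


args = RolloutArgs(temperature=0.7)    # user set temperature explicitly
set_if_none(args, "temperature", 1.0)  # model-level default does not override it
set_if_none(args, "top_k", -1)         # unset field falls back to the model-level default
assert args == RolloutArgs(temperature=0.7, top_k=-1)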

trinity/common/models/model.py

Lines changed: 16 additions & 5 deletions
@@ -59,6 +59,10 @@ def get_model_path(self) -> Optional[str]:
         """Get the model path"""
         return None
 
+    def get_default_rollout_args(self) -> dict:
+        """Get the default rollout arguments."""
+        raise NotImplementedError
+
 
 def _history_recorder(func):
     """Decorator to record history of the model calls."""
@@ -83,7 +87,7 @@ class ModelWrapper:
 
     def __init__(
         self,
-        model: Any,
+        model: InferenceModel,
         engine_type: str = "vllm",
         enable_lora: bool = False,
         enable_history: bool = False,
@@ -99,6 +103,8 @@ def __init__(
         self.history = []
         self.status = RunningStatus.RUNNING
         self.request_count = 0
+        self.default_rollout_args = model.get_default_rollout_args()
+        self.default_rollout_args["logprobs"] = True
 
     async def prepare(self) -> None:
         """Prepare the model wrapper."""
@@ -271,10 +277,12 @@ def get_openai_client(self) -> openai.OpenAI:
             )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = partial(self.openai_client.chat.completions.create, logprobs=True)
+            ori_create = self.openai_client.chat.completions.create
 
             def record_chat_completions(*args, **kwargs):
-                response = ori_create(*args, **kwargs)
+                default_kwargs = self.default_rollout_args.copy()
+                default_kwargs.update(kwargs)
+                response = ori_create(*args, **default_kwargs)
                 self.history.extend(convert_api_output_to_experience(response))
                 return response
@@ -301,10 +309,13 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI:
             )
         if self.enable_history:
             # add a decorator to the openai client to record history
-            ori_create = partial(self.openai_async_client.chat.completions.create, logprobs=True)
+            ori_create = self.openai_async_client.chat.completions.create
 
             async def record_chat_completions(*args, **kwargs):
-                response = await ori_create(*args, **kwargs)
+                default_kwargs = self.default_rollout_args.copy()
+                default_kwargs.update(kwargs)
+                # print(f"!!!!! {default_kwargs = }")
+                response = await ori_create(*args, **default_kwargs)
                 self.history.extend(convert_api_output_to_experience(response))
                 return response
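
Note: with this change, `ModelWrapper` seeds each OpenAI-style call with the engine's defaults (`get_default_rollout_args()` plus a forced `logprobs=True`) and lets per-call keyword arguments override them, since `dict.update` gives the caller's kwargs the last word. A minimal sketch of that merge order (values are illustrative, not taken from a real config):

# Defaults as the wrapper would assemble them: engine args plus logprobs=True.
default_rollout_args = {"temperature": 1.0, "top_p": 1.0, "top_k": -1, "logprobs": True}


def merge_call_kwargs(call_kwargs: dict) -> dict:
    # Copy the defaults first, then let explicit per-call kwargs win.
    merged = default_rollout_args.copy()
    merged.update(call_kwargs)
    return merged


assert merge_call_kwargs({"temperature": 0.2})["temperature"] == 0.2  # caller override wins
assert merge_call_kwargs({})["top_k"] == -1                           # default applies when unset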

trinity/common/models/vllm_model.py

Lines changed: 9 additions & 0 deletions
@@ -503,6 +503,15 @@ def get_model_version(self) -> int:
     def get_model_path(self) -> str:
         return self.config.model_path
 
+    def get_default_rollout_args(self) -> dict:
+        return {
+            "temperature": self.config.temperature,
+            "top_p": self.config.top_p,
+            "top_k": self.config.top_k,
+            "max_tokens": self.config.max_response_tokens,
+            # "n": self.config.repeat_times,
+        }
+
     def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest:
         assert self.config.lora_modules is not None
         lora_request = LoRARequest(**self.config.lora_modules[0])
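
Note: `logprobs` is intentionally absent from this dict; `ModelWrapper.__init__` adds `logprobs=True` on top of these engine defaults (see model.py above). With the new `ModelConfig` defaults, the returned mapping would look roughly like the following (the `max_tokens` value is hypothetical and depends on the user's `max_response_tokens` setting):

# Illustrative return value of get_default_rollout_args() for a config that keeps
# the default sampling values and sets max_response_tokens to 1024:
{
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "max_tokens": 1024,
}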

trinity/common/workflows/agentscope_workflow.py

Lines changed: 3 additions & 0 deletions
@@ -47,7 +47,10 @@ def __init__(
             model.get_openai_async_client(),
             generate_kwargs={
                 "temperature": self.task.rollout_args.temperature,
+                "top_p": self.task.rollout_args.top_p,
+                "top_k": self.task.rollout_args.top_k,
                 "max_tokens": self.task.rollout_args.max_tokens or 4096,
+                "logprobs": True,
                 "top_logprobs": self.task.rollout_args.logprobs,
             },
         )
