Commit 33a94bb

Simplify InferenceModel (#485)
1 parent 9eb1244 commit 33a94bb

3 files changed: +132 additions, -214 deletions

trinity/common/models/model.py

Lines changed: 118 additions & 0 deletions
@@ -16,6 +16,7 @@
 from trinity.common.config import InferenceModelConfig
 from trinity.common.constants import RunningStatus
 from trinity.common.experience import Experience
+from trinity.common.models.utils import get_action_mask_method
 from trinity.utils.log import get_logger


@@ -84,6 +85,123 @@ def get_model_path(self) -> Optional[str]:
         return self.config.model_path


+class BaseInferenceModel(InferenceModel):
+    """Base class for inference models containing common logic."""
+
+    def __init__(self, config: InferenceModelConfig) -> None:
+        super().__init__(config)
+        self.tokenizer = None
+        self.chat_template = None
+        if self.config.chat_template:
+            self.chat_template = self.config.chat_template
+        self.action_mask_method = get_action_mask_method(self.chat_template)
+        self.enable_thinking = config.enable_thinking
+
+    def apply_chat_template(
+        self,
+        tokenizer_or_processor,
+        messages: List[dict],
+    ) -> str:
+        assert tokenizer_or_processor is not None, "tokenizer_or_processor must be provided."
+        if self.chat_template is None:
+            assert self.tokenizer is not None, "self.tokenizer must be initialized."
+            self.chat_template = self.tokenizer.get_chat_template()
+
+        if messages[-1]["role"] == "assistant":
+            prompt = tokenizer_or_processor.apply_chat_template(
+                messages,
+                tokenize=False,
+                continue_final_message=True,
+                chat_template=self.chat_template,
+            )
+        else:
+            prompt = tokenizer_or_processor.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                chat_template=self.chat_template,
+                enable_thinking=self.enable_thinking,
+            )
+        return prompt
+
+    def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]:
+        """Handle prompt truncation if needed."""
+        # Tokenize once without truncation to check if truncation is needed
+        token_ids = self.tokenizer(  # type: ignore
+            prompt,
+            truncation=False,
+            return_tensors="pt",
+        )[
+            "input_ids"
+        ][0].tolist()
+
+        # Check if truncation is needed and apply it
+        if (
+            self.config.enable_prompt_truncation
+            and self.config.max_prompt_tokens is not None
+            and len(token_ids) > self.config.max_prompt_tokens
+        ):
+            self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens")
+            token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
+            return [
+                Experience(
+                    tokens=token_ids,
+                    logprobs=torch.zeros(1, dtype=torch.float32),
+                    prompt_length=len(token_ids) - 1,
+                    prompt_text=self.tokenizer.decode(token_ids[:-1]),
+                    response_text=self.tokenizer.decode(token_ids[-1]),
+                    truncate_status="prompt_truncated",
+                    reward=0.0,
+                )
+                for _ in range(kwargs.get("n", 1))
+            ], False
+        return token_ids, True
+
+    async def convert_messages_to_experience(
+        self,
+        messages: List[dict],
+        tools: Optional[List[dict]] = None,
+        temperature: Optional[float] = None,
+    ) -> Experience:
+        """Convert a list of messages into an experience in async."""
+        if self.tokenizer is None:
+            await self._initialize_tokenizer()
+        if self.chat_template is None:
+            self.chat_template = self.tokenizer.get_chat_template()
+        token_ids, action_mask, prompt_length = self.action_mask_method(
+            tokenizer=self.tokenizer,
+            messages=messages,
+            tools=tools,
+            chat_template=self.chat_template,
+            enable_thinking=self.enable_thinking,
+        )  # (seq_length, ), (seq_length, )
+
+        # Truncate tokens if they exceed the length limit
+        assert token_ids is not None
+        truncate_status = None
+        if self.config.max_model_len is not None and self.config.max_model_len > 0:
+            if len(token_ids) > self.config.max_model_len - 1:
+                truncate_status = "response_truncated"
+                self.logger.warning(
+                    f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
+                )
+                token_ids = token_ids[: self.config.max_model_len - 1]
+                action_mask = action_mask[: self.config.max_model_len - 1]
+
+        temperature = temperature if temperature is not None else self.config.temperature
+        logprobs = await self.logprobs(
+            token_ids=token_ids.tolist(), temperature=temperature
+        )  # (seq_length - 1,)
+        return Experience(
+            tokens=token_ids,
+            logprobs=logprobs[prompt_length - 1 :],
+            prompt_length=prompt_length,
+            action_mask=action_mask[prompt_length:],  # Exclude the prompt tokens
+            messages=messages,
+            truncate_status=truncate_status,
+        )
+
+
 def _history_recorder(func):
     """Decorator to record history of the model calls."""

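The block above hoists the tokenizer/chat-template handling, prompt truncation, and message-to-experience conversion into the shared BaseInferenceModel, so a backend only has to supply the backend-specific hooks (tokenizer initialization, generation, logprob scoring). Below is a minimal, illustrative sketch of such a subclass; the DummyBackendModel name, its placeholder generation logic, and the use of transformers.AutoTokenizer are assumptions for this sketch, not part of the commit.

# Illustrative only: a hypothetical backend built on the new BaseInferenceModel.
# DummyBackendModel, its placeholder generation, and AutoTokenizer usage are
# assumptions for this sketch, not part of the commit.
from typing import List, Sequence

import torch
from transformers import AutoTokenizer

from trinity.common.experience import Experience
from trinity.common.models.model import BaseInferenceModel


class DummyBackendModel(BaseInferenceModel):
    """Only the backend-specific hooks are implemented; the rest is inherited."""

    async def _initialize_tokenizer(self) -> None:
        # Assumption: config.model_path points at a HF-style checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

    async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
        if self.tokenizer is None:
            await self._initialize_tokenizer()
        # Shared helper: returns either prompt token ids (is_valid=True) or
        # ready-made "prompt_truncated" Experiences (is_valid=False).
        token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
        if not is_valid:
            return token_ids
        # A real backend would sample here; this placeholder emits one EOS token.
        response_ids = [self.tokenizer.eos_token_id]
        full_ids = list(token_ids) + response_ids
        return [
            Experience(
                tokens=full_ids,
                logprobs=torch.zeros(len(response_ids), dtype=torch.float32),
                prompt_length=len(token_ids),
                prompt_text=prompt,
                response_text=self.tokenizer.decode(response_ids),
            )
            for _ in range(kwargs.get("n", 1))
        ]

    async def logprobs(self, token_ids: List[int], **kwargs) -> torch.Tensor:
        # Placeholder scores; a real backend would compute per-token logprobs.
        return torch.zeros(len(token_ids) - 1, dtype=torch.float32)
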
trinity/common/models/tinker_model.py

Lines changed: 6 additions & 98 deletions
@@ -10,12 +10,11 @@

 from trinity.common.config import InferenceModelConfig
 from trinity.common.experience import Experience
-from trinity.common.models.model import InferenceModel
-from trinity.common.models.utils import get_action_mask_method
+from trinity.common.models.model import BaseInferenceModel
 from trinity.manager.synchronizer import Synchronizer


-class TinkerModel(InferenceModel):
+class TinkerModel(BaseInferenceModel):
     def __init__(
         self,
         config: InferenceModelConfig,
@@ -25,12 +24,6 @@ def __init__(
         self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace)
         self.model = None
         self.model_path = config.model_path
-        self.tokenizer = None
-        self.chat_template = None
-        if self.config.chat_template:
-            self.chat_template = self.config.chat_template
-        self.action_mask_method = get_action_mask_method(self.chat_template)
-        self.enable_thinking = config.enable_thinking

     async def _initialize_tokenizer(self) -> None:
         """Initialize the tokenizer."""
@@ -62,34 +55,9 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         if self.tokenizer is None:
             await self._initialize_tokenizer()

-        # Tokenize once without truncation to check if truncation is needed
-        token_ids = self.tokenizer(  # type: ignore
-            prompt,
-            truncation=False,
-            return_tensors="pt",
-        )[
-            "input_ids"
-        ][0].tolist()
-
-        # Check if truncation is needed and apply it
-        if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
-            if len(token_ids) > self.config.max_prompt_tokens:
-                self.logger.warning(
-                    f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
-                )
-                token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
-                return [
-                    Experience(
-                        tokens=token_ids,
-                        logprobs=torch.zeros(1, dtype=torch.float32),
-                        prompt_length=len(token_ids) - 1,
-                        prompt_text=self.tokenizer.decode(token_ids[:-1]),
-                        response_text=self.tokenizer.decode(token_ids[-1]),
-                        truncate_status="prompt_truncated",
-                        reward=0.0,
-                    )
-                    for _ in range(kwargs.get("n", 1))
-                ]
+        token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
+        if not is_valid:
+            return token_ids

         with_chat_completion = kwargs.get("with_chat_completion", False)
         if with_chat_completion:
@@ -157,8 +125,6 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
         """Generate experiences from a list of history chat messages in async."""
         if self.tokenizer is None:
             await self._initialize_tokenizer()
-        if self.chat_template is None:
-            self.chat_template = self.tokenizer.get_chat_template()

         # TODO: this is a hack to support openai chat messages, which only supports text
         for msg in messages:
@@ -169,72 +135,14 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
                 content_str = msg["content"]
             msg["content"] = content_str

-        if messages[-1]["role"] == "assistant":
-            prompt = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                continue_final_message=True,
-                chat_template=self.chat_template,
-            )
-        else:
-            prompt = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                chat_template=self.chat_template,
-                enable_thinking=self.enable_thinking,
-            )
+        prompt = self.apply_chat_template(self.tokenizer, messages)
         return await self.generate(prompt=prompt, **kwargs)

     async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
         """Generate logprobs for a list of tokens in async."""
         logprobs = await self.model.compute_logprobs_async(types.ModelInput.from_ints(token_ids))
         return torch.tensor(logprobs[1:], dtype=torch.float32)

-    async def convert_messages_to_experience(
-        self,
-        messages: List[dict],
-        tools: Optional[List[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> Experience:
-        """Convert a list of messages into an experience in async."""
-        if self.tokenizer is None:
-            await self._initialize_tokenizer()
-        if self.chat_template is None:
-            self.chat_template = self.tokenizer.get_chat_template()
-        token_ids, action_mask, prompt_length = self.action_mask_method(
-            tokenizer=self.tokenizer,
-            messages=messages,
-            tools=tools,
-            chat_template=self.chat_template,
-            enable_thinking=self.enable_thinking,
-        )  # (seq_length, ), (seq_length, )
-
-        # Truncate tokens if they exceed the length limit
-        assert token_ids is not None
-        truncate_status = None
-        if self.config.max_model_len is not None and self.config.max_model_len > 0:
-            if len(token_ids) > self.config.max_model_len - 1:
-                truncate_status = "response_truncated"
-                self.logger.warning(
-                    f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
-                )
-                token_ids = token_ids[: self.config.max_model_len - 1]
-                action_mask = action_mask[: self.config.max_model_len - 1]
-
-        temperature = temperature if temperature is not None else self.config.temperature
-        logprobs = await self.logprobs(
-            token_ids=token_ids.tolist(), temperature=temperature
-        )  # (seq_length - 1,)
-        return Experience(
-            tokens=token_ids,
-            logprobs=logprobs[prompt_length - 1 :],
-            prompt_length=prompt_length,
-            action_mask=action_mask[prompt_length:],  # Exclude the prompt tokens
-            messages=messages,
-            truncate_status=truncate_status,
-        )
-
     async def prepare(self) -> None:
         """Prepare the model before inference."""
         self.service_client = tinker.ServiceClient()