 
 from trinity.common.config import InferenceModelConfig
 from trinity.common.experience import Experience
-from trinity.common.models.model import InferenceModel
-from trinity.common.models.utils import get_action_mask_method
+from trinity.common.models.model import BaseInferenceModel
 from trinity.manager.synchronizer import Synchronizer
 
 
-class TinkerModel(InferenceModel):
+class TinkerModel(BaseInferenceModel):
     def __init__(
         self,
         config: InferenceModelConfig,
@@ -25,12 +24,6 @@ def __init__(
         self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace)
         self.model = None
         self.model_path = config.model_path
-        self.tokenizer = None
-        self.chat_template = None
-        if self.config.chat_template:
-            self.chat_template = self.config.chat_template
-        self.action_mask_method = get_action_mask_method(self.chat_template)
-        self.enable_thinking = config.enable_thinking
 
     async def _initialize_tokenizer(self) -> None:
         """Initialize the tokenizer."""
@@ -62,34 +55,9 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
         if self.tokenizer is None:
             await self._initialize_tokenizer()
 
-        # Tokenize once without truncation to check if truncation is needed
-        token_ids = self.tokenizer(  # type: ignore
-            prompt,
-            truncation=False,
-            return_tensors="pt",
-        )[
-            "input_ids"
-        ][0].tolist()
-
-        # Check if truncation is needed and apply it
-        if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None:
-            if len(token_ids) > self.config.max_prompt_tokens:
-                self.logger.warning(
-                    f"Prompt was truncated to {self.config.max_prompt_tokens} tokens"
-                )
-                token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one for response
-                return [
-                    Experience(
-                        tokens=token_ids,
-                        logprobs=torch.zeros(1, dtype=torch.float32),
-                        prompt_length=len(token_ids) - 1,
-                        prompt_text=self.tokenizer.decode(token_ids[:-1]),
-                        response_text=self.tokenizer.decode(token_ids[-1]),
-                        truncate_status="prompt_truncated",
-                        reward=0.0,
-                    )
-                    for _ in range(kwargs.get("n", 1))
-                ]
+        token_ids, is_valid = self._handle_prompt_truncation(prompt, **kwargs)
+        if not is_valid:
+            return token_ids
 
         with_chat_completion = kwargs.get("with_chat_completion", False)
         if with_chat_completion:
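The inline tokenize-and-truncate logic deleted above is replaced by a single call to `_handle_prompt_truncation`, which presumably now lives on `BaseInferenceModel` and is not shown in this diff. Below is a minimal sketch of what such a helper could look like, reconstructed from the deleted lines; only the name and the `(value, is_valid)` return convention are taken from the call site, everything else is an assumption.

```python
# Hypothetical sketch of the helper assumed to live on BaseInferenceModel.
# Reconstructed from the inline logic removed above; not the actual implementation.
def _handle_prompt_truncation(self, prompt: str, **kwargs):
    """Return (token_ids, True) if the prompt fits, or
    (placeholder Experiences, False) if it had to be truncated."""
    token_ids = self.tokenizer(prompt, truncation=False, return_tensors="pt")[
        "input_ids"
    ][0].tolist()

    if (
        self.config.enable_prompt_truncation
        and self.config.max_prompt_tokens is not None
        and len(token_ids) > self.config.max_prompt_tokens
    ):
        self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens")
        token_ids = token_ids[: self.config.max_prompt_tokens + 1]  # leave one token for the response
        experiences = [
            Experience(
                tokens=token_ids,
                logprobs=torch.zeros(1, dtype=torch.float32),
                prompt_length=len(token_ids) - 1,
                prompt_text=self.tokenizer.decode(token_ids[:-1]),
                response_text=self.tokenizer.decode(token_ids[-1]),
                truncate_status="prompt_truncated",
                reward=0.0,
            )
            for _ in range(kwargs.get("n", 1))
        ]
        return experiences, False  # caller returns these directly
    return token_ids, True
```

Note that when `is_valid` is false the first element is already a list of placeholder `Experience` objects, which is why `generate` can return it as-is.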
@@ -157,8 +125,6 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
         """Generate experiences from a list of history chat messages in async."""
         if self.tokenizer is None:
             await self._initialize_tokenizer()
-        if self.chat_template is None:
-            self.chat_template = self.tokenizer.get_chat_template()
 
         # TODO: this is a hack to support openai chat messages, which only supports text
         for msg in messages:
@@ -169,72 +135,14 @@ async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
             content_str = msg["content"]
             msg["content"] = content_str
 
-        if messages[-1]["role"] == "assistant":
-            prompt = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                continue_final_message=True,
-                chat_template=self.chat_template,
-            )
-        else:
-            prompt = self.tokenizer.apply_chat_template(
-                messages,
-                tokenize=False,
-                add_generation_prompt=True,
-                chat_template=self.chat_template,
-                enable_thinking=self.enable_thinking,
-            )
+        prompt = self.apply_chat_template(self.tokenizer, messages)
         return await self.generate(prompt=prompt, **kwargs)
 
     async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
         """Generate logprobs for a list of tokens in async."""
         logprobs = await self.model.compute_logprobs_async(types.ModelInput.from_ints(token_ids))
         return torch.tensor(logprobs[1:], dtype=torch.float32)
 
-    async def convert_messages_to_experience(
-        self,
-        messages: List[dict],
-        tools: Optional[List[dict]] = None,
-        temperature: Optional[float] = None,
-    ) -> Experience:
-        """Convert a list of messages into an experience in async."""
-        if self.tokenizer is None:
-            await self._initialize_tokenizer()
-        if self.chat_template is None:
-            self.chat_template = self.tokenizer.get_chat_template()
-        token_ids, action_mask, prompt_length = self.action_mask_method(
-            tokenizer=self.tokenizer,
-            messages=messages,
-            tools=tools,
-            chat_template=self.chat_template,
-            enable_thinking=self.enable_thinking,
-        )  # (seq_length, ), (seq_length, )
-
-        # Truncate tokens if they exceed the length limit
-        assert token_ids is not None
-        truncate_status = None
-        if self.config.max_model_len is not None and self.config.max_model_len > 0:
-            if len(token_ids) > self.config.max_model_len - 1:
-                truncate_status = "response_truncated"
-                self.logger.warning(
-                    f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}"
-                )
-                token_ids = token_ids[: self.config.max_model_len - 1]
-                action_mask = action_mask[: self.config.max_model_len - 1]
-
-        temperature = temperature if temperature is not None else self.config.temperature
-        logprobs = await self.logprobs(
-            token_ids=token_ids.tolist(), temperature=temperature
-        )  # (seq_length - 1,)
-        return Experience(
-            tokens=token_ids,
-            logprobs=logprobs[prompt_length - 1 :],
-            prompt_length=prompt_length,
-            action_mask=action_mask[prompt_length:],  # Exclude the prompt tokens
-            messages=messages,
-            truncate_status=truncate_status,
-        )
-
     async def prepare(self) -> None:
         """Prepare the model before inference."""
         self.service_client = tinker.ServiceClient()
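Similarly, the chat-template branching deleted from `chat()` collapses into a single `self.apply_chat_template(self.tokenizer, messages)` call, presumably shared through `BaseInferenceModel` along with the `chat_template` and `enable_thinking` attributes dropped from `__init__`. A rough sketch under that assumption, reconstructed from the removed branches; only the call-site signature is given by the diff, the body is illustrative.

```python
# Hypothetical sketch of the shared helper; reconstructed from the branches
# removed from chat(). The real base-class implementation may differ.
def apply_chat_template(self, tokenizer, messages: List[dict]) -> str:
    if messages[-1]["role"] == "assistant":
        # Continue an unfinished assistant turn instead of opening a new one.
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            continue_final_message=True,
            chat_template=self.chat_template,
        )
    # Otherwise append a generation prompt for a fresh assistant turn.
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        chat_template=self.chat_template,
        enable_thinking=self.enable_thinking,
    )
```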