@@ -215,6 +215,55 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
            [tokenizer.token_to_id("<|eot_id|>")],
        )

+class R1Base(PromptStyle):
+    def apply(self, prompt: Union[str, List[Dict[str, str]]], **kwargs: str) -> str:
+        default_system_prompt = ""
+
+        bos_token = "<|begin▁of▁sentence|>"
+        eos_token = ""
+
+        if isinstance(prompt, str):
+            return (
+                f"{default_system_prompt}"
+                f"<|User|>{prompt}"
+                f"<|Assistant|>"  # Prepares for assistant response
+            )
+        elif isinstance(prompt, list):
+
+            def encode_message(message: Dict[str, str]) -> str:
+                role = message["role"]
+                content = message["content"].strip()
+
+                if role == "system":
+                    return content  # System prompt is prepended at the start
+                elif role == "user":
+                    return f"<|User|>{content}"
+                elif role == "assistant":
+                    return f"<|Assistant|>{content}{eos_token}"
+                else:
+                    raise ValueError(f"Unknown role: '{role}'. Supported roles are 'assistant', 'user', and 'system'.")
+
+            # Extract system prompt (if any)
+            system_prompt = ""
+            if prompt[0].get("role") == "system":
+                system_prompt = prompt[0]["content"]
+                prompt = prompt[1:]  # Remove system message from the list
+
+            # Construct the formatted prompt
+            formatted_prompt = system_prompt
+            for message in prompt:
+                formatted_prompt += encode_message(message)
+
+            formatted_prompt += "<|Assistant|>"  # Prepares for assistant response
+            return formatted_prompt
+        else:
+            raise ValueError(f"Unsupported prompt type: {type(prompt)}")
+
+    def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
+        return (
+            [tokenizer.eos_id],
+            [tokenizer.token_to_id("<|end▁of▁sentence|>")],
+        )

class FreeWilly2(PromptStyle):
    def apply(self, prompt: str, **kwargs: str) -> str:
@@ -372,6 +421,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
        return Llama3()
    if re.search("Llama-3.*-Instruct-*", model_name):
        return Llama3()
+    if re.search("R1", model_name):
+        return R1Base()
    if re.search("FreeWilly2", model_name):
        return FreeWilly2()
    if re.search("Platypus", model_name):
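For reviewers, a minimal usage sketch of the new prompt style (hypothetical, not part of the diff; it assumes the patch is applied and that this file is litgpt/prompts.py, so R1Base can be imported from litgpt.prompts):

```python
# Hypothetical example: the string R1Base.apply() is expected to produce for a
# short chat transcript, given the implementation added in this diff.
from litgpt.prompts import R1Base  # assumes this patch has been applied

style = R1Base()
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
    {"role": "user", "content": "And doubled?"},
]

formatted = style.apply(messages)
print(formatted)
# Expected output (system text first, then role tags, ending with an open
# <|Assistant|> tag so generation can continue; eos_token is "" in this class):
# You are a concise assistant.<|User|>What is 2 + 2?<|Assistant|>4<|User|>And doubled?<|Assistant|>
```

Note that bos_token is defined but never embedded in the returned string; presumably the tokenizer is expected to prepend "<|begin▁of▁sentence|>" itself.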