Commit c7af97d

ali-alshaar7 and Ali Alshaarawy authored
add missing r1 prompt style (#1929)
Co-authored-by: Ali Alshaarawy <[email protected]>
1 parent 1d93671 commit c7af97d

File tree

1 file changed: +51 -0 lines changed


litgpt/prompts.py

Lines changed: 51 additions & 0 deletions
@@ -215,6 +215,55 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
             [tokenizer.token_to_id("<|eot_id|>")],
         )
 
+class R1Base(PromptStyle):
+    def apply(self, prompt: Union[str, List[Dict[str, str]]], **kwargs: str) -> str:
+        default_system_prompt = ""
+
+        bos_token = "<|begin▁of▁sentence|>"
+        eos_token = ""
+
+        if isinstance(prompt, str):
+            return (
+                f"{default_system_prompt}"
+                f"<|User|>{prompt}"
+                f"<|Assistant|>"  # Prepares for assistant response
+            )
+        elif isinstance(prompt, list):
+
+            def encode_message(message: Dict[str, str]) -> str:
+                role = message["role"]
+                content = message["content"].strip()
+
+                if role == "system":
+                    return content  # System prompt is prepended at the start
+                elif role == "user":
+                    return f"<|User|>{content}"
+                elif role == "assistant":
+                    return f"<|Assistant|>{content}{eos_token}"
+                else:
+                    raise ValueError(f"Unknown role: '{role}'. Supported roles are 'assistant', 'user', and 'system'.")
+
+            # Extract system prompt (if any)
+            system_prompt = ""
+            if prompt[0].get("role") == "system":
+                system_prompt = prompt[0]["content"]
+                prompt = prompt[1:]  # Remove system message from the list
+
+            # Construct the formatted prompt
+            formatted_prompt = system_prompt
+            for message in prompt:
+                formatted_prompt += encode_message(message)
+
+            formatted_prompt += "<|Assistant|>"  # Prepares for assistant response
+            return formatted_prompt
+        else:
+            raise ValueError(f"Unsupported prompt type: {type(prompt)}")
+
+    def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
+        return (
+            [tokenizer.eos_id],
+            [tokenizer.token_to_id("<|end▁of▁sentence|>")],
+        )
 
 class FreeWilly2(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
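
For readers skimming the diff, a minimal usage sketch of the formatting this class produces. This is illustrative only and not part of the commit; it assumes R1Base is importable from litgpt.prompts (as the file path above suggests), and the example prompts are invented.

# Minimal sketch, not part of the commit: exercising R1Base.apply as added above.
from litgpt.prompts import R1Base

style = R1Base()

# Plain-string prompt: default_system_prompt is "" in this style.
print(style.apply("What is 2 + 2?"))
# -> <|User|>What is 2 + 2?<|Assistant|>

# Chat-style prompt: a leading system message is prepended verbatim,
# then each turn is wrapped in <|User|> / <|Assistant|> markers.
messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]
print(style.apply(messages))
# -> You are a concise assistant.<|User|>What is 2 + 2?<|Assistant|>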
@@ -372,6 +421,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Llama3()
     if re.search("Llama-3.*-Instruct-*", model_name):
         return Llama3()
+    if re.search("R1", model_name):
+        return R1Base()
     if re.search("FreeWilly2", model_name):
         return FreeWilly2()
     if re.search("Platypus", model_name):
