@@ -352,14 +352,33 @@ def apply_chat_template_and_mask(
     tokenizer: PreTrainedTokenizer,
     chat: List[Dict[str, str]],
     max_length: Optional[int] = None,
+    system_prompt: Optional[str] = None,
     padding: bool = True,
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+
+    if system_prompt is None:
+        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+    system_element = {
+        "role": "system",
+        "content": system_prompt,
+    }
+
+    # Format for RL. Default both fields to None so the returns at the end of
+    # the function stay well defined when `chat` is a plain list of messages.
+    gt_answer = None
+    test_cases = None
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
+        chat = [chat["messages"]]
+
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
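
For reference, the RL branch above expects a sample to be a dict rather than a bare list of chat turns, and `chat = [chat["messages"]]` implies a single-turn message object. A minimal sketch of both shapes, with made-up contents (only the field names come from the diff):

    # Hypothetical RL-style sample: "messages" is one user turn, "gt_answer" the reference answer.
    rl_sample = {
        "messages": {"role": "user", "content": "What is 7 * 9?"},
        "gt_answer": "63",
    }
    # Plain SFT-style sample: a list of message dicts, which skips the RL branch entirely.
    sft_sample = [{"role": "user", "content": "What is 7 * 9?"}]
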
@@ -372,14 +391,10 @@ def apply_chat_template_and_mask(
     if max_length is not None:
         if padding and len(tokens) < max_length:
             to_pad = max_length - len(tokens)
-            if tokenizer.padding_side == "right":
-                tokens.extend([tokenizer.pad_token_id] * to_pad)
-                assistant_mask.extend([False] * to_pad)
-                attention_mask.extend([0] * to_pad)
-            else:
-                tokens = [tokenizer.pad_token_id] * to_pad + tokens
-                assistant_mask = [False] * to_pad + assistant_mask
-                attention_mask = [0] * to_pad + attention_mask
+            # Left padding for generation.
+            tokens = [tokenizer.pad_token_id] * to_pad + tokens
+            assistant_mask = [False] * to_pad + assistant_mask
+            attention_mask = [0] * to_pad + attention_mask
         if truncation and len(tokens) > max_length:
             tokens = tokens[:max_length]
             assistant_mask = assistant_mask[:max_length]
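
The collapse to unconditional left padding is deliberate for decoder-only generation: `generate()` appends new tokens at the right end of the sequence, so pad tokens must sit on the left or they would separate the prompt from its continuation. The same effect obtained through the tokenizer alone, as a sketch (the model name is only an example):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # any decoder-only checkpoint works here
    tok.pad_token = tok.eos_token  # gpt2 ships without a pad token
    tok.padding_side = "left"      # mirrors the hard-coded left padding above
    batch = tok(["Hi", "A much longer prompt"], padding=True, return_tensors="pt")
    # Each row now ends with real prompt tokens, so generation continues from them.
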
@@ -389,6 +404,15 @@ def apply_chat_template_and_mask(
     labels = input_ids.clone()
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx
 
+    if gt_answer is not None:
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -402,21 +426,39 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """
 
-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
             for line in f:
                 self.raw_texts.append(line)
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
+        self.system_prompt = system_prompt
 
     def __len__(self) -> int:
         return len(self.raw_texts)
 
     def __getitem__(self, index: int):
         if self.tokenized_texts[index] is None:
             message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
+            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length;
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id).
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
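
Finally, a sketch of the intended end-to-end wiring: build the dataset with the widened constructor, then batch it with the new collate function. The file path and batch size are placeholders, and passing system_prompt=None falls back to the default <think>/<answer> prompt above:

    from torch.utils.data import DataLoader

    dataset = RawConversationDataset(
        tokenizer,
        "data/train.jsonl",  # one JSON object per line, e.g. the RL sample shape above
        max_length=512,
        system_prompt=None,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn_grpo)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)  # torch.Size([4, 512])
    print(batch.get("gt_answer"))    # list of reference answers, one per sample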