@@ -352,14 +352,32 @@ def apply_chat_template_and_mask(
     tokenizer: PreTrainedTokenizer,
     chat: List[Dict[str, str]],
     max_length: Optional[int] = None,
+    system_prompt: Optional[str] = None,
     padding: bool = True,
     truncation: bool = True,
     ignore_idx: int = -100,
 ) -> Dict[str, torch.Tensor]:
+
+    if system_prompt is None:
+        system_prompt = "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, i.e., <answer> 123 </answer>.\n\n"
+
+    system_element = {
+        "role": "system",
+        "content": system_prompt,
+    }
+
+    # Format for RL.
+    gt_answer = None  # default so the SFT path below never hits an unbound name
+    test_cases = None
+    if "messages" in chat:
+        gt_answer = chat.get("gt_answer", None)
+        test_cases = chat.get("test_cases", None)
+        chat = [chat["messages"]]
+
     tokens = []
     assistant_mask = []
     for i, msg in enumerate(chat):
-        msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
+        msg_tokens = tokenizer.apply_chat_template([system_element, msg], tokenize=True, add_generation_prompt=True)
         # remove unexpected bos token
         if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
             msg_tokens = msg_tokens[1:]
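
Aside: the `"messages" in chat` branch implies a raw sample can be either a plain list of chat turns (SFT-style) or a dict bundling one message with a verifiable label. A sketch of both shapes as inferred from the membership check and `.get()` calls above; the exact schema is an assumption, not documented in this diff:

    # SFT-style sample: a bare list of chat turns.
    sft_sample = [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "<think>2 + 2 = 4</think><answer>4</answer>"},
    ]

    # RL-style sample (assumed shape): one message plus an optional verifiable label.
    rl_sample = {
        "messages": {"role": "user", "content": "What is 2 + 2?"},
        "gt_answer": "4",       # consumed by a math reward function
        # "test_cases": [...],  # or unit tests for a code reward
    }
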
@@ -372,14 +390,10 @@ def apply_chat_template_and_mask(
     if max_length is not None:
         if padding and len(tokens) < max_length:
             to_pad = max_length - len(tokens)
-            if tokenizer.padding_side == "right":
-                tokens.extend([tokenizer.pad_token_id] * to_pad)
-                assistant_mask.extend([False] * to_pad)
-                attention_mask.extend([0] * to_pad)
-            else:
-                tokens = [tokenizer.pad_token_id] * to_pad + tokens
-                assistant_mask = [False] * to_pad + assistant_mask
-                attention_mask = [0] * to_pad + attention_mask
+            # Left padding for generation.
+            tokens = [tokenizer.pad_token_id] * to_pad + tokens
+            assistant_mask = [False] * to_pad + assistant_mask
+            attention_mask = [0] * to_pad + attention_mask
         if truncation and len(tokens) > max_length:
             tokens = tokens[:max_length]
             assistant_mask = assistant_mask[:max_length]
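
The reason the hunk above drops right padding entirely: batched generation with a decoder-only model reads the logits at each row's last position, so real tokens must be flush against the right edge. A toy illustration (token IDs are made up, not from this codebase):

    import torch

    pad_id = 0
    seqs = [[5, 6, 7], [8, 9]]
    max_len = 3
    # Left padding keeps every row's final position a real token,
    # so next-token logits line up across the batch during generation.
    input_ids = torch.tensor([[pad_id] * (max_len - len(s)) + s for s in seqs])
    attention_mask = torch.tensor([[0] * (max_len - len(s)) + [1] * len(s) for s in seqs])
    print(input_ids)       # tensor([[5, 6, 7], [0, 8, 9]])
    print(attention_mask)  # tensor([[1, 1, 1], [0, 1, 1]])
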
@@ -389,6 +403,15 @@ def apply_chat_template_and_mask(
     labels = input_ids.clone()
     labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

+    if gt_answer is not None:
+        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, "gt_answer": gt_answer}
+    elif test_cases is not None:
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "test_cases": test_cases,
+        }
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
@@ -402,21 +425,39 @@ class RawConversationDataset(Dataset):
     Each instance is a dictionary with fields `system`, `roles`, `messages`, `offset`, `sep_style`, `seps`.
     """

-    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
+    def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int, system_prompt: str) -> None:
         self.tokenizer = tokenizer
         self.raw_texts = []
         with jsonlines.open(input_file) as f:
             for line in f:
                 self.raw_texts.append(line)
         self.tokenized_texts = [None] * len(self.raw_texts)
         self.max_length = max_length
+        self.system_prompt = system_prompt

     def __len__(self) -> int:
         return len(self.raw_texts)

     def __getitem__(self, index: int):
         if self.tokenized_texts[index] is None:
             message = self.raw_texts[index]
-            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
+            tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length, self.system_prompt)
             self.tokenized_texts[index] = dict(tokens)
         return self.tokenized_texts[index]
+
+
+def collate_fn_grpo(batch):
+    input_ids = [item["input_ids"] for item in batch]
+    attention_mask = [item["attention_mask"] for item in batch]
+    labels = [item["labels"] for item in batch]
+    # Assume input_ids, attention_mask, labels are already of the same length;
+    # otherwise use pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id).
+    input_ids = torch.stack(input_ids)
+    attention_mask = torch.stack(attention_mask)
+    labels = torch.stack(labels)
+    ret = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+    if "test_cases" in batch[0]:
+        ret["test_cases"] = [item["test_cases"] for item in batch]
+    if "gt_answer" in batch[0]:
+        ret["gt_answer"] = [item["gt_answer"] for item in batch]
+    return ret
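
For orientation, a minimal sketch of how the new pieces could be wired together; the model name and file path are placeholders, and only `RawConversationDataset` and `collate_fn_grpo` come from this diff:

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # Hypothetical setup: tokenizer choice and data path are illustrative only.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
    dataset = RawConversationDataset(tokenizer, "train.jsonl", max_length=2048, system_prompt=None)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn_grpo)

    for batch in loader:
        # input_ids / attention_mask / labels are stacked tensors; gt_answer /
        # test_cases (when present) stay as Python lists for the reward function.
        prompts = batch["input_ids"]
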