@@ -69,7 +69,8 @@ def create_dataset(**dataset_config):
69
69
task_dataset_path = task_dataset_path ,
70
70
task_dataset_prob = task_dataset_prob ,
71
71
sub_dataset_type = sub_dataset_type ,
72
- process_fn = (process_fc if dataset_config ["sub_dataset_type" ] == "chatml" else process_example ),
72
+ process_fn = process_example ,
73
+ process_fn_fc = process_fc ,
73
74
)
74
75
sequence_dataset = SequenceDataset (
75
76
dataset = example_dataset ,
@@ -174,7 +175,7 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, model_args, max_seq_len:
174
175
175
176
def process_fc (data , input_file ):
176
177
multi_turns_messages = data ["messages" ]
177
- tools_list = data ["tools" ]
178
+ tools_list = data ["tools" ] if "tools" in data else None
178
179
label = data ["label" ] if "label" in data else None
179
180
180
181
system = ""
@@ -507,17 +508,26 @@ def __iter__(self):
507
508
508
509
def function_call_chat_template (self , messages , tools ):
509
510
history = messages [:- 1 ]
511
+ input_dict = dict ()
512
+ input_dict ["messages" ] = history
513
+ if tools is not None :
514
+ input_dict ["tools" ] = tools
510
515
history_str = self .tokenizer .apply_chat_template (
511
- { "messages" : history , "tools" : tools } ,
516
+ input_dict ,
512
517
add_generation_prompt = True ,
513
518
tokenize = False ,
514
519
)
515
520
history_len = len (history_str )
521
+ input_dict ["messages" ] = messages
516
522
all_str = self .tokenizer .apply_chat_template (
517
- { "messages" : messages , "tools" : tools } ,
523
+ input_dict ,
518
524
add_generation_prompt = False ,
519
525
tokenize = False ,
520
526
)
527
+ # (21b think model) remove generation content
528
+ s = "<|im_end|>\n \n <|im_start|>assistant\n <think>\n "
529
+ if all_str .endswith (s ):
530
+ all_str = all_str [: - len (s )]
521
531
response_str = all_str [history_len :]
522
532
history_id = self .tokenizer .convert_tokens_to_ids (self .tokenizer .tokenize (history_str ))
523
533
response_id = self .tokenizer .convert_tokens_to_ids (self .tokenizer .tokenize (response_str ))
@@ -591,7 +601,7 @@ def _postprocess_sequence(self, example, actual_example_num):
591
601
if LOGGER_COUNT <= 5 :
592
602
logger .warning (f"even one turn, example_output:'{{'src':[{ sub_src } , ……],'tgt':[……{ sub_tgt } ]}}'" )
593
603
except Exception :
594
- logger .warning (f "[SKIP] wrong example: { example } " )
604
+ logger .warning ("[SKIP] wrong example" )
595
605
596
606
return None
597
607
0 commit comments