@@ -83,6 +83,7 @@ def create_dataset(**dataset_config):
         packing=dataset_config["packing"],
         mix_strategy=dataset_config["mix_strategy"],
         encode_one_turn=dataset_config["encode_one_turn"],
+        use_template=dataset_config["use_template"],
     )
     return sequence_dataset
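With this change, every config passed to `create_dataset` must supply a `use_template` key. A minimal caller sketch, assuming only the keys visible in this diff (values taken from the `__init__` defaults below; the real config likely carries more keys):

```python
# Hypothetical caller sketch; only the four keys shown in this diff are listed.
dataset_config = {
    "packing": False,          # default from __init__
    "mix_strategy": "random",  # default from __init__
    "encode_one_turn": True,   # default from __init__
    "use_template": True,      # new flag introduced by this change
}
sequence_dataset = create_dataset(**dataset_config)
```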
@@ -289,6 +290,7 @@ def __init__(
         packing: bool = False,
         mix_strategy: str = "random",
         encode_one_turn: bool = True,
+        use_template: bool = True,
     ):
         """Initialize SequenceDataset.
@@ -319,6 +321,7 @@ def __init__(
         self.packing = packing
         self.mix_strategy = mix_strategy
         self.encode_one_turn = encode_one_turn
+        self.use_template = use_template
         self.num_samples_each_epoch = num_samples_each_epoch
         self.reverse = True
@@ -536,12 +539,19 @@ def _postprocess_sequence(self, example, actual_example_num):
         Returns:
             Sequence: Processed sequence or None if invalid.
         """
-        if not self.tokenizer.chat_template:
-            self.tokenizer.chat_template = NONE_CHAT_TEMPLATE
-        if example.is_function_call:
-            encoded_messages = self._postprocess_fc_sequence(example)
+        if self.use_template:
+            if not self.tokenizer.chat_template:
+                self.tokenizer.chat_template = NONE_CHAT_TEMPLATE
+            if example.is_function_call:
+                encoded_messages = self._postprocess_fc_sequence(example)
+            else:
+                encoded_messages = self.tokenizer.encode_chat_inputs(
+                    example.request, encode_one_turn=self.encode_one_turn
+                )
         else:
-            encoded_messages = self.tokenizer.encode_chat_inputs(example.request, encode_one_turn=self.encode_one_turn)
+            encoded_messages = self.tokenizer.encode_chat_inputs_with_no_template(
+                example.request, encode_one_turn=self.encode_one_turn
+            )
 
         num_reserved_tokens_for_each_dialog = 1  # only break_turn_token or end_token
         num_reserved_tokens_for_each_turn = 8
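The branch above selects among three encoding paths: function-call examples go through `_postprocess_fc_sequence`, templated chat goes through `encode_chat_inputs`, and the new `use_template=False` path goes through `encode_chat_inputs_with_no_template`. A minimal sketch of the template vs. no-template distinction with a toy renderer (the role-marker format here is illustrative, not the project's actual chat template):

```python
def render_with_template(messages):
    # Templated path: wrap every turn in role markers, roughly what a
    # chat_template produces before tokenization.
    return "".join(f"<|{m['role']}|>{m['content']}<|end|>" for m in messages)

def render_without_template(messages):
    # Template-free path: concatenate raw turn contents; no role markers
    # or special control tokens are injected.
    return "".join(m["content"] for m in messages)

messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
print(render_with_template(messages))     # <|user|>Hi<|end|><|assistant|>Hello!<|end|>
print(render_without_template(messages))  # HiHello!
```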
@@ -585,26 +595,36 @@ def _postprocess_sequence(self, example, actual_example_num):
 
             return None
 
-        if self.begin_token_id is not None and self.end_of_response_id is not None:
-            # Maybe left truncated, so need to add begin_token
-            if tokens[0] != self.begin_token_id:
-                tokens = [self.begin_token_id] + tokens
-                loss_mask = [0] + loss_mask
-
-            if len(tokens) > self.max_seq_len:
-                raise RuntimeError(f"token_ids is too long: {len(tokens)}")
-
-            # Add EOS token at the end
-            del tokens[-1]
-            del loss_mask[-1]
-            labels = tokens[1:] + [self.tokenizer.eos_token_id]
-
-            # end_of_response is a special token that indicates the end of the turn.
-            # end_token is a special token that indicates the end of the answer.
-            labels = [label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels]
+        if self.use_template:
+            if self.begin_token_id is not None and self.end_of_response_id is not None:
+                # Maybe left truncated, so need to add begin_token
+                if tokens[0] != self.begin_token_id:
+                    tokens = [self.begin_token_id] + tokens
+                    loss_mask = [0] + loss_mask
+
+                if len(tokens) > self.max_seq_len:
+                    raise RuntimeError(f"token_ids is too long: {len(tokens)}")
+
+                # Add EOS token at the end
+                del tokens[-1]
+                del loss_mask[-1]
+                labels = tokens[1:] + [self.tokenizer.eos_token_id]
+
+                # end_of_response is a special token that indicates the end of the turn.
+                # end_token is a special token that indicates the end of the answer.
+                labels = [
+                    label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels
+                ]
+            else:
+                tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
+                labels = tokens[1:] + [-100]
+                if len(tokens) > self.max_seq_len:
+                    raise RuntimeError(f"token_ids is too long: {len(tokens)}")
         else:
-            tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
-            labels = tokens[1:] + [-100]
+            orig_tokens = tokens
+            tokens = orig_tokens[:-1]
+            labels = orig_tokens[1:]
+            loss_mask = loss_mask[1:]
 
         if len(tokens) > self.max_seq_len:
             raise RuntimeError(f"token_ids is too long: {len(tokens)}")
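In the new `use_template=False` branch above, no EOS token is appended; the raw stream is simply split into a shifted (input, label) pair, and `loss_mask` drops its first element so it stays aligned with `labels`. A toy illustration of that shift (integer values are made up, not real token ids):

```python
# Toy token stream and per-token loss mask (illustrative values only).
orig_tokens = [101, 7, 8, 9, 102]
loss_mask = [0, 0, 1, 1, 1]

# Next-token prediction: position i predicts token i + 1.
tokens = orig_tokens[:-1]  # model inputs:          [101, 7, 8, 9]
labels = orig_tokens[1:]   # shifted targets:       [7, 8, 9, 102]
loss_mask = loss_mask[1:]  # realigned with labels: [0, 1, 1, 1]

assert len(tokens) == len(labels) == len(loss_mask)
```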