@@ -980,7 +980,7 @@ def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result_
980980
981981 # Create the Dataset Generator:
982982 class SampleExpansionGenerator :
983- def __init__ (self , raw_text_samples , tokenizer , sample_expansion_batch_size = 50 ):
983+ def __init__ (self , raw_text_samples , tokenizer , sample_expansion_batch_size = 50 , prompt_length_0 = PROMPT_LENGTH , max_seq_length = MAX_SEQ_LENGTH ):
984984 self .raw_text_samples = raw_text_samples
985985 self .tokenizer = tokenizer
986986 self .sample_expansion_batch_size = sample_expansion_batch_size
@@ -1006,7 +1006,13 @@ def _expand_next_batch(self):
10061006 self .current_index = end_idx
10071007
10081008 # Run prepare_data on this batch
1009- input_ids_list , labels_list , _ = prepare_data (batch_samples , max_seq_length = MAX_SEQ_LENGTH )
1009+ input_ids_list , labels_list , _ = \
1010+ prepare_data (
1011+ data_0 = batch_samples ,
1012+ tokenizer_0 = tokenizer ,
1013+ max_seq_length = max_seq_length ,
1014+ prompt_length = prompt_length_0 )
1015+ # input_ids_list, labels_list, _ = prepare_data(batch_samples, max_seq_length=MAX_SEQ_LENGTH) # <<--<<
10101016
10111017 # Assign to internal queues
10121018 self .data = input_ids_list
0 commit comments