@@ -980,99 +980,84 @@ def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result_
 
 # Create the Dataset Generator:
 class SampleExpansionGenerator:
-    def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=50, prompt_length_0=PROMPT_LENGTH, max_seq_length=MAX_SEQ_LENGTH):
+    def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10, prompt_length_0=PROMPT_LENGTH, max_seq_length=MAX_SEQ_LENGTH, vocabulary_size=VOCABULARY_SIZE):
         self.raw_text_samples = raw_text_samples
         self.tokenizer = tokenizer
         self.sample_expansion_batch_size = sample_expansion_batch_size
+        self.model_batch_size = model_batch_size  # Add this parameter
+        self.prompt_length_0 = prompt_length_0
+        self.max_seq_length = max_seq_length
+        self.vocabulary_size = vocabulary_size
         self.data = []
         self.labels = []
         self.current_index = 0
-
+
     def _expand_next_batch(self):
-
         # Determine the next meta-batch
         start_idx = self.current_index
         end_idx = min(start_idx + self.sample_expansion_batch_size, len(self.raw_text_samples))
         collect()
-        # if start_idx >= end_idx:
-        #     self.current_index = 0 # raise StopIteration("No more raw samples to process.")
-        #     start_idx = 0
-        #     end_idx = min(self.sample_expansion_batch_size, len(self.raw_text_samples))
-
+
         if start_idx >= end_idx:
             raise StopIteration("No more raw samples to process.")
-
+
         batch_samples = self.raw_text_samples[start_idx:end_idx]
         self.current_index = end_idx
-
-        # Run prepare_data on this batch
-        input_ids_list, labels_list, _ = \
-            prepare_data(
-                data_0=batch_samples,
-                tokenizer_0=tokenizer,
-                max_seq_length=max_seq_length,
-                prompt_length=prompt_length_0)
-        # input_ids_list, labels_list, _ = prepare_data(batch_samples, max_seq_length=MAX_SEQ_LENGTH) # <<--<<
-
+
+        # Run prepare_data on this batch - use the instance parameters
+        input_ids_list, labels_list, _ = prepare_data(
+            data_0=batch_samples,
+            tokenizer_0=self.tokenizer,
+            max_seq_length=self.max_seq_length,
+            prompt_length=self.prompt_length_0)
+
         # Assign to internal queues
         self.data = input_ids_list
         self.labels = labels_list
-
-    # def __iter__(self):
-    #     return self
-
-    # def __iter__(self):
-    #     # Create a fresh instance with the same parameters
-    #     return SampleExpansionGenerator(
-    #         self.raw_text_samples,
-    #         self.tokenizer,
-    #         self.sample_expansion_batch_size
-    #     )
-
+
     def __iter__(self):
         # Reset to initial state for new epoch
         self.current_index = 0
         self.data = []
         self.labels = []
         return self
-
+
     def __next__(self):
         # Check for mismatched state
         if (len(self.data) == 0) != (len(self.labels) == 0):
             raise ValueError("Data and labels queues are out of sync.")
-
+
         # If queues are empty, expand next batch
         if len(self.data) == 0:
             self._expand_next_batch()
-
+
         # Pop and return one sample
-        # input_sample = [self.data.pop(0)] # Nested as per model input spec
-        # label_sample = [self.labels.pop(0)] # Nested as per model output spec
         input_sample = self.data.pop(0)
         label_sample = self.labels.pop(0)
-
+
         return (input_sample, label_sample)
-
-
+
 # Create the tf.data.Dataset
 def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10) -> tf.data.Dataset:
-    generator = SampleExpansionGenerator(raw_text_samples, tokenizer, sample_expansion_batch_size)
-
+    generator_0 = SampleExpansionGenerator(
+        raw_text_samples=raw_text_samples,
+        tokenizer=tokenizer,
+        sample_expansion_batch_size=sample_expansion_batch_size,
+        model_batch_size=model_batch_size  # Pass this parameter
+    )
+
     dataset = tf.data.Dataset.from_generator(
-        lambda: generator,
+        lambda: generator_0,
         output_signature=(
-            tf.TensorSpec(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32),
-            tf.TensorSpec(shape=(VOCABULARY_SIZE,), dtype=tf.float32)
-            # tf.TensorSpec(shape=(1, MAX_SEQ_LENGTH), dtype=tf.int32), # Nested input
-            # tf.TensorSpec(shape=(1, VOCABULARY_SIZE), dtype=tf.float32) # Nested one-hot label
+            tf.TensorSpec(shape=(generator_0.max_seq_length,), dtype=tf.int32),  # Use generator's parameter
+            tf.TensorSpec(shape=(generator_0.vocabulary_size,), dtype=tf.float32)  # Use generator's parameter
         )
     )
-    # Set dataset to allow multiple epochs:
-    # dataset = dataset.repeat()
+
     # Batch it
     dataset = dataset.batch(model_batch_size)
     return dataset
-
+
 phase_i_b_train_dataset = \
     create_dataset(
         raw_text_samples=phase_i_b_train_samples,
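
For reference, below is a minimal, self-contained sketch of the expand-on-demand iterator pattern this commit settles on: per-sample queues refilled lazily, one meta-batch of raw samples at a time, with __iter__ resetting state so the same object can drive multiple epochs. prepare_data_stub and the toy sizes are placeholder stand-ins for the repository's real prepare_data and hyperparameters, not its actual values.

import numpy as np

def prepare_data_stub(samples, max_seq_length=8, vocabulary_size=16):
    # Fake tokenization: one (input_ids, one_hot_label) pair per raw sample.
    inputs = [np.zeros(max_seq_length, dtype=np.int32) for _ in samples]
    labels = [np.eye(vocabulary_size, dtype=np.float32)[0] for _ in samples]
    return inputs, labels

class ExpandOnDemandIterator:
    def __init__(self, raw_samples, expansion_batch_size=3):
        self.raw_samples = raw_samples
        self.expansion_batch_size = expansion_batch_size
        self.data, self.labels, self.current_index = [], [], 0

    def __iter__(self):
        # Reset so each epoch re-walks the raw samples from the start.
        self.data, self.labels, self.current_index = [], [], 0
        return self

    def __next__(self):
        if not self.data:
            start = self.current_index
            end = min(start + self.expansion_batch_size, len(self.raw_samples))
            if start >= end:
                raise StopIteration  # Raw samples exhausted: ends the epoch.
            self.data, self.labels = prepare_data_stub(self.raw_samples[start:end])
            self.current_index = end
        return self.data.pop(0), self.labels.pop(0)

# Two full passes succeed because __iter__ resets the queues:
it = ExpandOnDemandIterator(["a", "b", "c", "d"], expansion_batch_size=3)
assert sum(1 for _ in it) == 4
assert sum(1 for _ in it) == 4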
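The from_generator wiring then follows the same shape as the patched create_dataset: output_signature describes a single un-batched sample, and dataset.batch() prepends the batch axis afterwards. A hedged sketch reusing the stub iterator above (the _STUB constants are illustrative, not the project's MAX_SEQ_LENGTH or VOCABULARY_SIZE):

import tensorflow as tf

MAX_SEQ_LENGTH_STUB = 8
VOCABULARY_SIZE_STUB = 16

dataset = tf.data.Dataset.from_generator(
    # A fresh iterable per call; tf.data may re-invoke this on each epoch.
    lambda: ExpandOnDemandIterator(["a", "b", "c", "d"]),
    output_signature=(
        tf.TensorSpec(shape=(MAX_SEQ_LENGTH_STUB,), dtype=tf.int32),
        tf.TensorSpec(shape=(VOCABULARY_SIZE_STUB,), dtype=tf.float32),
    ),
).batch(2)

for inputs, labels in dataset:
    print(inputs.shape, labels.shape)  # (2, 8) and (2, 16) for each full batch

Batching after from_generator is what lets the per-sample TensorSpec shapes stay free of a leading batch dimension, which the removed "nested" specs in the diff suggest had previously been a point of confusion.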