
Commit 18a774b

Update generative-proof-of-concept-CPU-preprocessing-in-memory.py
Refactor generator and dataset.
1 parent d78b666 commit 18a774b

File tree

1 file changed (+34, −49 lines)


generative-proof-of-concept-CPU-preprocessing-in-memory.py

Lines changed: 34 additions & 49 deletions
@@ -980,99 +980,84 @@ def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result_
 
 # Create the Dataset Generator:
 class SampleExpansionGenerator:
-    def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=50, prompt_length_0=PROMPT_LENGTH, max_seq_length=MAX_SEQ_LENGTH):
+    def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10, prompt_length_0=PROMPT_LENGTH, max_seq_length=MAX_SEQ_LENGTH, vocabulary_size=VOCABULARY_SIZE):
         self.raw_text_samples = raw_text_samples
         self.tokenizer = tokenizer
         self.sample_expansion_batch_size = sample_expansion_batch_size
+        self.model_batch_size = model_batch_size  # Add this parameter
+        self.prompt_length_0 = prompt_length_0
+        self.max_seq_length = max_seq_length
+        self.vocabulary_size = vocabulary_size
         self.data = []
         self.labels = []
         self.current_index = 0
-
+
     def _expand_next_batch(self):
-
         # Determine the next meta-batch
         start_idx = self.current_index
         end_idx = min(start_idx + self.sample_expansion_batch_size, len(self.raw_text_samples))
         collect()
-        # if start_idx >= end_idx:
-        #     self.current_index = 0  # raise StopIteration("No more raw samples to process.")
-        #     start_idx = 0
-        #     end_idx = min(self.sample_expansion_batch_size, len(self.raw_text_samples))
-
+
         if start_idx >= end_idx:
             raise StopIteration("No more raw samples to process.")
-
+
         batch_samples = self.raw_text_samples[start_idx:end_idx]
         self.current_index = end_idx
-
-        # Run prepare_data on this batch
-        input_ids_list, labels_list, _ =\
-            prepare_data(
-                data_0=batch_samples,
-                tokenizer_0=tokenizer,
-                max_seq_length=max_seq_length,
-                prompt_length=prompt_length_0)
-        # input_ids_list, labels_list, _ = prepare_data(batch_samples, max_seq_length=MAX_SEQ_LENGTH)  # <<--<<
-
+
+        # Run prepare_data on this batch - use the instance parameters
+        input_ids_list, labels_list, _ = prepare_data(
+            data_0=batch_samples,
+            tokenizer_0=self.tokenizer,
+            max_seq_length=self.max_seq_length,
+            prompt_length=self.prompt_length_0)
+
         # Assign to internal queues
         self.data = input_ids_list
         self.labels = labels_list
-
-    # def __iter__(self):
-    #     return self
-
-    # def __iter__(self):
-    #     # Create a fresh instance with the same parameters
-    #     return SampleExpansionGenerator(
-    #         self.raw_text_samples,
-    #         self.tokenizer,
-    #         self.sample_expansion_batch_size
-    #     )
-
+
     def __iter__(self):
         # Reset to initial state for new epoch
         self.current_index = 0
         self.data = []
         self.labels = []
         return self
-
+
     def __next__(self):
         # Check for mismatched state
         if (len(self.data) == 0) != (len(self.labels) == 0):
             raise ValueError("Data and labels queues are out of sync.")
-
+
         # If queues are empty, expand next batch
         if len(self.data) == 0:
             self._expand_next_batch()
-
+
         # Pop and return one sample
-        # input_sample = [self.data.pop(0)]  # Nested as per model input spec
-        # label_sample = [self.labels.pop(0)]  # Nested as per model output spec
         input_sample = self.data.pop(0)
         label_sample = self.labels.pop(0)
-
+
         return (input_sample, label_sample)
-
-
+
 # Create the tf.data.Dataset
 def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10) -> tf.data.Dataset:
-    generator = SampleExpansionGenerator(raw_text_samples, tokenizer, sample_expansion_batch_size)
-
+    generator_0 = SampleExpansionGenerator(
+        raw_text_samples=raw_text_samples,
+        tokenizer=tokenizer,
+        sample_expansion_batch_size=sample_expansion_batch_size,
+        model_batch_size=model_batch_size  # Pass this parameter
+    )
+
     dataset = tf.data.Dataset.from_generator(
-        lambda: generator,
+        lambda: generator_0,
         output_signature=(
-            tf.TensorSpec(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32),
-            tf.TensorSpec(shape=(VOCABULARY_SIZE,), dtype=tf.float32)
-            # tf.TensorSpec(shape=(1, MAX_SEQ_LENGTH), dtype=tf.int32),  # Nested input
-            # tf.TensorSpec(shape=(1, VOCABULARY_SIZE), dtype=tf.float32)  # Nested one-hot label
+            tf.TensorSpec(shape=(generator_0.max_seq_length,), dtype=tf.int32),  # Use the generator's parameter
+            tf.TensorSpec(shape=(generator_0.vocabulary_size,), dtype=tf.float32)  # Use the generator's parameter
        )
    )
-    # Set dataset to allow multiple epochs:
-    # dataset = dataset.repeat()
+
    # Batch it
    dataset = dataset.batch(model_batch_size)
    return dataset
-
+
 phase_i_b_train_dataset =\
     create_dataset(
         raw_text_samples=phase_i_b_train_samples,
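
For reference, the pattern this commit settles on — a stateful iterator object handed to tf.data.Dataset.from_generator through a zero-argument callable, with per-sample TensorSpecs and batching applied afterwards — can be sketched in isolation. The sketch below is illustrative only: ToyGenerator, TOY_SEQ_LEN, and TOY_VOCAB are stand-ins for the repository's SampleExpansionGenerator, MAX_SEQ_LENGTH, and VOCABULARY_SIZE, and the real prepare_data preprocessing is replaced with dummy tensors.

    import tensorflow as tf

    TOY_SEQ_LEN = 8    # stand-in for MAX_SEQ_LENGTH
    TOY_VOCAB = 16     # stand-in for VOCABULARY_SIZE

    class ToyGenerator:
        """Stateful sample source; __iter__ resets it for each epoch."""
        def __init__(self, n_samples, seq_len=TOY_SEQ_LEN, vocab=TOY_VOCAB):
            self.n_samples = n_samples
            self.seq_len = seq_len
            self.vocab = vocab
            self.i = 0

        def __iter__(self):
            self.i = 0  # reset per epoch, mirroring the commit's __iter__
            return self

        def __next__(self):
            if self.i >= self.n_samples:
                raise StopIteration
            self.i += 1
            ids = tf.zeros((self.seq_len,), dtype=tf.int32)      # dummy token ids
            label = tf.one_hot(self.i % self.vocab, self.vocab)  # float32 one-hot
            return ids, label

    gen = ToyGenerator(n_samples=25)
    ds = tf.data.Dataset.from_generator(
        lambda: gen,  # zero-argument callable, as in the commit
        output_signature=(
            tf.TensorSpec(shape=(TOY_SEQ_LEN,), dtype=tf.int32),
            tf.TensorSpec(shape=(TOY_VOCAB,), dtype=tf.float32),
        ),
    ).batch(10)  # per-sample specs; batching adds the leading dimension

    for epoch in range(2):  # two epochs work because __iter__ resets state
        for x, y in ds:
            print(epoch, x.shape, y.shape)  # (10, 8)/(10, 16), last batch (5, ...)

Because from_generator receives a callable rather than the iterator itself, TensorFlow re-invokes it and calls iter() on the result each time the dataset is iterated. With a single shared instance, multi-epoch training therefore depends on __iter__ clearing current_index and the internal queues — which is exactly what the refactored __iter__ does, making the commented-out dataset.repeat() workaround unnecessary.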
