-
Notifications
You must be signed in to change notification settings - Fork 6
Open
Labels
status/ready-pending-tests: Ready to make pull request once tests pass.
Description
From #266, merge target: #266
class SampleExpansionGenerator:
    """Iterator that lazily expands raw text samples into (input, label) pairs.

    Raw samples are consumed in slices of ``sample_expansion_batch_size``.
    Each slice is handed to ``prepare_data`` and the returned inputs and
    labels are queued, then served one pair per ``__next__`` call. Calling
    ``iter()`` on the instance rewinds it to the first raw sample, so the
    same object can be iterated repeatedly (e.g. once per epoch).
    """

    def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10, prompt_length_0=PROMPT_LENGTH, max_seq_length=MAX_SEQ_LENGTH, vocabulary_size=VOCABULARY_SIZE):
        """Store the configuration and start with empty queues.

        Args:
            raw_text_samples: Sequence of raw samples; must support ``len()``
                and slicing.
            tokenizer: Forwarded to ``prepare_data`` as ``tokenizer_0``.
            sample_expansion_batch_size: Number of raw samples expanded per
                ``prepare_data`` call.
            model_batch_size: Stored on the instance; not used by any method
                of this class.
            prompt_length_0: Forwarded to ``prepare_data`` as
                ``prompt_length``.
            max_seq_length: Forwarded to ``prepare_data``.
            vocabulary_size: Stored on the instance; not used by any method
                of this class.
        """
        self.raw_text_samples = raw_text_samples
        self.tokenizer = tokenizer
        self.sample_expansion_batch_size = sample_expansion_batch_size
        self.model_batch_size = model_batch_size
        self.prompt_length_0 = prompt_length_0
        self.max_seq_length = max_seq_length
        self.vocabulary_size = vocabulary_size
        # Queues of expanded samples waiting to be served by __next__.
        self.data = []
        self.labels = []
        # Index of the next raw sample that has not been expanded yet.
        self.current_index = 0

    def _expand_next_batch(self):
        """Expand the next slice of raw samples and refill both queues.

        Raises:
            StopIteration: If every raw sample has already been consumed.
        """
        start_idx = self.current_index
        end_idx = min(start_idx + self.sample_expansion_batch_size, len(self.raw_text_samples))
        # NOTE(review): `collect` is not defined in this block; presumably it
        # is `gc.collect` imported at file level -- confirm. It runs before
        # the exhaustion check, so it is also invoked on the final call.
        collect()
        if start_idx >= end_idx:
            raise StopIteration("No more raw samples to process.")
        batch_samples = self.raw_text_samples[start_idx:end_idx]
        self.current_index = end_idx
        # The third value returned by prepare_data is not needed here.
        input_ids_list, labels_list, _ = prepare_data(
            data_0=batch_samples,
            tokenizer_0=self.tokenizer,
            max_seq_length=self.max_seq_length,
            prompt_length=self.prompt_length_0)
        # Replace (not extend) the queues; __next__ only calls this method
        # once both queues are empty.
        self.data = input_ids_list
        self.labels = labels_list

    def __iter__(self):
        """Rewind to the first raw sample, drop queued samples, return self."""
        self.current_index = 0
        self.data = []
        self.labels = []
        return self

    def __next__(self):
        """Return the next ``(input_sample, label_sample)`` pair.

        Raises:
            StopIteration: When all raw samples have been expanded and served.
            ValueError: If the data and label queues differ in length.
        """
        # Keep expanding until samples are available. A single expansion may
        # yield nothing, in which case popping would raise IndexError rather
        # than advancing to the next slice. The loop terminates because each
        # call either advances current_index or raises StopIteration.
        while not self.data and not self.labels:
            self._expand_next_batch()
        # Compare lengths (not just emptiness) so a mismatch is reported
        # before any mispaired sample is served.
        if len(self.data) != len(self.labels):
            raise ValueError("Data and labels queues are out of sync.")
        input_sample = self.data.pop(0)
        label_sample = self.labels.pop(0)
        return (input_sample, label_sample)
# Build the batched tf.data.Dataset on top of SampleExpansionGenerator
def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10) -> tf.data.Dataset:
    """Wrap a SampleExpansionGenerator in a batched ``tf.data.Dataset``.

    Args:
        raw_text_samples: Raw samples forwarded to SampleExpansionGenerator.
        tokenizer: Tokenizer forwarded to SampleExpansionGenerator.
        sample_expansion_batch_size: Forwarded to SampleExpansionGenerator.
        model_batch_size: Forwarded to SampleExpansionGenerator and also used
            as the batch size of the returned dataset.

    Returns:
        A dataset whose elements are declared as an int32 tensor of shape
        ``(max_seq_length,)`` paired with a float32 tensor of shape
        ``(vocabulary_size,)``, grouped into batches of ``model_batch_size``.
    """
    sample_source = SampleExpansionGenerator(
        raw_text_samples=raw_text_samples,
        tokenizer=tokenizer,
        sample_expansion_batch_size=sample_expansion_batch_size,
        model_batch_size=model_batch_size,
    )

    # Element shapes come from the generator's own configuration.
    input_spec = tf.TensorSpec(shape=(sample_source.max_seq_length,), dtype=tf.int32)
    label_spec = tf.TensorSpec(shape=(sample_source.vocabulary_size,), dtype=tf.float32)

    def _sample_stream():
        # The same generator instance is handed back on every call.
        return sample_source

    unbatched = tf.data.Dataset.from_generator(
        _sample_stream,
        output_signature=(input_spec, label_spec),
    )
    return unbatched.batch(model_batch_size)
Metadata
Metadata
Assignees
Labels
status/ready-pending-tests: Ready to make pull request once tests pass.