Skip to content

refactor-generator-dataset-for-llm #267

@david-thrower

Description

@david-thrower

From #266, merge target: #266

class SampleExpansionGenerator:
    """Iterator that lazily expands raw text samples into (input_ids, label) pairs.

    Raw samples are tokenized/expanded in meta-batches of
    ``sample_expansion_batch_size`` via ``prepare_data`` (defined elsewhere in
    this module), queued internally, and yielded one sample at a time. Designed
    to back ``tf.data.Dataset.from_generator`` so only one meta-batch of
    expanded samples is resident in memory at a time.
    """

    def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10, prompt_length_0=PROMPT_LENGTH, max_seq_length=MAX_SEQ_LENGTH, vocabulary_size=VOCABULARY_SIZE):
        # raw_text_samples: sequence of untokenized text samples to expand.
        # tokenizer: tokenizer handed through to prepare_data.
        # sample_expansion_batch_size: how many raw samples to expand per meta-batch.
        # model_batch_size: model-level batch size; stored for callers
        #   (e.g. create_dataset) — not used in the expansion logic itself.
        # prompt_length_0 / max_seq_length / vocabulary_size: tokenization and
        #   output-shape parameters; defaults come from module-level constants.
        self.raw_text_samples = raw_text_samples
        self.tokenizer = tokenizer
        self.sample_expansion_batch_size = sample_expansion_batch_size
        self.model_batch_size = model_batch_size
        self.prompt_length_0 = prompt_length_0
        self.max_seq_length = max_seq_length
        self.vocabulary_size = vocabulary_size
        self.data = []    # queue of expanded input samples awaiting emission
        self.labels = []  # parallel queue of labels
        self.current_index = 0  # position of the next unexpanded raw sample

    def _expand_next_batch(self):
        """Expand the next meta-batch of raw samples into the internal queues.

        Raises:
            StopIteration: when every raw sample has already been consumed.
        """
        start_idx = self.current_index
        end_idx = min(start_idx + self.sample_expansion_batch_size, len(self.raw_text_samples))
        # Reclaim memory from the previous (now-discarded) expanded batch
        # before allocating the next one.
        collect()

        if start_idx >= end_idx:
            raise StopIteration("No more raw samples to process.")

        batch_samples = self.raw_text_samples[start_idx:end_idx]
        self.current_index = end_idx

        # Tokenize/expand this meta-batch using the instance parameters.
        input_ids_list, labels_list, _ = prepare_data(
            data_0=batch_samples,
            tokenizer_0=self.tokenizer,
            max_seq_length=self.max_seq_length,
            prompt_length=self.prompt_length_0)

        # Replace the internal queues with the freshly expanded batch.
        self.data = input_ids_list
        self.labels = labels_list

    def __iter__(self):
        """Reset to the initial state so a new epoch re-reads all raw samples."""
        self.current_index = 0
        self.data = []
        self.labels = []
        return self

    def __next__(self):
        """Return the next (input_sample, label_sample) pair.

        Raises:
            ValueError: if the data and label queues fall out of sync.
            StopIteration: when all raw samples have been expanded and emitted.
        """
        # Exactly one queue being empty means a previous pop was unpaired.
        if (len(self.data) == 0) != (len(self.labels) == 0):
            raise ValueError("Data and labels queues are out of sync.")

        # Refill until at least one sample is available. A loop (rather than a
        # single refill) guards against prepare_data returning empty lists for
        # a meta-batch, which would otherwise crash the pop below with
        # IndexError; _expand_next_batch raises StopIteration at exhaustion.
        while len(self.data) == 0:
            self._expand_next_batch()

        # Pop and return one sample from the front of each queue.
        input_sample = self.data.pop(0)
        label_sample = self.labels.pop(0)

        return (input_sample, label_sample)

# Create the tf.data.Dataset
def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=50, model_batch_size=10) -> tf.data.Dataset:
    """Build a batched tf.data.Dataset backed by a SampleExpansionGenerator.

    The generator lazily expands raw text samples, and the resulting element
    stream is batched to ``model_batch_size``. Element shapes come from the
    generator's own ``max_seq_length`` and ``vocabulary_size`` settings.
    """
    sample_generator = SampleExpansionGenerator(
        raw_text_samples=raw_text_samples,
        tokenizer=tokenizer,
        sample_expansion_batch_size=sample_expansion_batch_size,
        model_batch_size=model_batch_size,
    )

    # One int32 token-id vector in, one float32 vocabulary-sized vector out.
    element_spec = (
        tf.TensorSpec(shape=(sample_generator.max_seq_length,), dtype=tf.int32),
        tf.TensorSpec(shape=(sample_generator.vocabulary_size,), dtype=tf.float32),
    )

    # The lambda re-hands the same generator object to tf.data; its __iter__
    # resets internal state, so each epoch restarts from the first raw sample.
    unbatched = tf.data.Dataset.from_generator(
        lambda: sample_generator,
        output_signature=element_spec,
    )

    return unbatched.batch(model_batch_size)

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions