@@ -1331,7 +1331,7 @@ def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result_
13311331
13321332 # Create the Dataset Generator:
13331333 class SampleExpansionGenerator :
1334- def __init__ (self , raw_text_samples , tokenizer , sample_expansion_batch_size = 5 ):
1334+ def __init__ (self , raw_text_samples , tokenizer , sample_expansion_batch_size = 50 ):
13351335 self .raw_text_samples = raw_text_samples
13361336 self .tokenizer = tokenizer
13371337 self .sample_expansion_batch_size = sample_expansion_batch_size
@@ -1400,7 +1400,7 @@ def __next__(self):
14001400
14011401
14021402 # Create the tf.data.Dataset
1403- def create_dataset (raw_text_samples , tokenizer , sample_expansion_batch_size = 10 ) -> tf .data .Dataset :
1403+ def create_dataset (raw_text_samples , tokenizer , sample_expansion_batch_size = 50 , model_batch_size = 10 ) -> tf .data .Dataset :
14041404 generator = SampleExpansionGenerator (raw_text_samples , tokenizer , sample_expansion_batch_size )
14051405
14061406 dataset = tf .data .Dataset .from_generator (
@@ -1415,21 +1415,23 @@ def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=10)
14151415 # Set dataset to allow multiple epochs:
14161416 # dataset = dataset.repeat()
14171417 # Batch it
1418- dataset = dataset .batch (batch_size )
1418+ dataset = dataset .batch (model_batch_size )
14191419 return dataset
14201420
14211421 phase_i_b_train_dataset = \
14221422 create_dataset (
14231423 raw_text_samples = phase_i_b_train_samples ,
14241424 tokenizer = tokenizer ,
1425- sample_expansion_batch_size = PHASE_I_B_SAMPLE_EXPANSION_BATCH_SIZE )
1425+ sample_expansion_batch_size = PHASE_I_B_SAMPLE_EXPANSION_BATCH_SIZE ,
1426+ model_batch_size = batch_size )
14261427
14271428
14281429 phase_i_b_val_dataset = \
14291430 create_dataset (
14301431 raw_text_samples = phase_i_b_val_samples ,
14311432 tokenizer = tokenizer ,
1432- sample_expansion_batch_size = PHASE_I_B_SAMPLE_EXPANSION_BATCH_SIZE )
1433+ sample_expansion_batch_size = PHASE_I_B_SAMPLE_EXPANSION_BATCH_SIZE ,
1434+ model_batch_size = batch_size )
14331435
14341436
14351437 phase_i_b_history = \
0 commit comments