@@ -583,9 +583,16 @@ def create_streaming_training_data(text_samples: List[str],
 
 # Step 3: Train/test split
 
-train_samples_list_text, test_samples_list_text = train_test_split(
+train_samples_list_text, val_plus_test_samples_list_text = train_test_split(
     non_instruct_samples,
-    test_size=0.2,
+    test_size=0.3,
+    shuffle=True)
+
+# Val and test set split
+
+val_samples_list_text, test_samples_list_text = train_test_split(
+    val_plus_test_samples_list_text,
+    test_size=0.3,
     shuffle=True)
 
 del(non_instruct_samples)
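With `test_size=0.3` in both calls, the held-out 30% is itself split 70/30, so the overall proportions come out to roughly 70% train, 21% validation, and 9% test. A minimal standalone sketch of the same two-stage split, assuming plain scikit-learn and a toy sample list (the `random_state` here is added only for reproducibility and is not in the commit):

```python
from sklearn.model_selection import train_test_split

samples = [f"sample {i}" for i in range(100)]  # toy stand-in for non_instruct_samples

# First split: 70% train, 30% held out for validation + test.
train, val_plus_test = train_test_split(
    samples, test_size=0.3, shuffle=True, random_state=42)

# Second split: 70% of the held-out chunk becomes validation, 30% becomes test,
# i.e. ~21% and ~9% of the full dataset respectively.
val, test = train_test_split(
    val_plus_test, test_size=0.3, shuffle=True, random_state=42)

print(len(train), len(val), len(test))  # 70 21 9
```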
@@ -600,6 +607,13 @@ def create_streaming_training_data(text_samples: List[str],
 )
 
 
+# Set up step 4 iterator for val set:
+x_val_packaged, y_val_packaged = create_streaming_training_data(
+    text_samples=val_samples_list_text,    # Validation split of text samples
+    text_expansion_batch_size=2            # Expand 2 text samples at a time (~1GB memory)
+)
+
+
 # Set up step 4 iterator for test set:
 x_test_packaged, y_test_packaged = create_streaming_training_data(
     text_samples=test_samples_list_text,   # Held-out test split of text samples
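The body of `create_streaming_training_data` sits outside these hunks, so the sketch below is only a plausible reading of its signature: a pair of lazy iterators that expand `text_expansion_batch_size` raw samples at a time into model-ready arrays, keeping peak memory bounded. The byte-level "tokenization", `seq_len`, and `itertools.tee` pairing are all illustrative assumptions, not the commit's actual code:

```python
import itertools
from typing import Iterator, List, Tuple

import numpy as np

def create_streaming_training_data(
    text_samples: List[str],
    text_expansion_batch_size: int = 2,
    seq_len: int = 128,
) -> Tuple[Iterator[np.ndarray], Iterator[np.ndarray]]:
    """Sketch: lazily expand a few raw text samples at a time into (x, y) arrays."""
    def pairs() -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        for i in range(0, len(text_samples), text_expansion_batch_size):
            chunk = text_samples[i : i + text_expansion_batch_size]
            # Placeholder tokenization: raw byte values stand in for token ids.
            ids = [list(s.encode("utf-8"))[: seq_len + 1] for s in chunk]
            ids = [seq + [0] * (seq_len + 1 - len(seq)) for seq in ids]  # right-pad
            arr = np.asarray(ids, dtype=np.int32)
            yield arr[:, :-1], arr[:, 1:]  # inputs / shifted next-token targets

    xs, ys = itertools.tee(pairs())
    return (x for x, _ in xs), (y for _, y in ys)
```

Note that `itertools.tee` buffers whichever stream runs ahead, so the two iterators should be consumed in lockstep to keep the memory bound meaningful.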
@@ -876,7 +890,8 @@ def from_config(cls, config):
     output_shapes=OUTPUT_SHAPES,
     training_data=x_train_packaged,
     labels=y_train_packaged,
-    validation_split=0.2,
+    validation_split=0.0,
+    validation_data=(x_val_packaged, y_val_packaged),
     direction='maximize',
     metric_to_rank_by="val_categorical_accuracy",
     minimum_levels=minimum_levels,
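The net effect of this last hunk is the standard Keras pattern for a pre-split hold-out set: `validation_split` is zeroed out and the explicitly constructed validation iterators go in via `validation_data`, which is also the only option when the training data itself is streamed rather than held in memory. The surrounding call (`direction`, `metric_to_rank_by`, `minimum_levels`) appears to be a hyperparameter-search wrapper whose API isn't shown here, so the same idea is illustrated below with plain `keras.Model.fit` on toy arrays:

```python
import numpy as np
from tensorflow import keras

# Toy stand-ins for the packaged train/val data in the commit.
x_train, y_train = np.random.rand(80, 4), np.random.rand(80, 1)
x_val, y_val = np.random.rand(20, 4), np.random.rand(20, 1)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Explicit hold-out set: no validation_split, so val_* metrics are computed on
# exactly the samples produced by the two-stage train_test_split above.
model.fit(x_train, y_train, epochs=1, validation_data=(x_val, y_val))
```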