Skip to content

Commit ca7b34f

Browse files
Update phishing_email_detection_gpt2.py
Fix unsupported validation_split for non-numpy-equivalent ...
1 parent 96d39fe commit ca7b34f

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

phishing_email_detection_gpt2.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -583,9 +583,16 @@ def create_streaming_training_data(text_samples: List[str],
583583

584584
# Step 3 Train test split
585585

586-
train_samples_list_text, test_samples_list_text = train_test_split(
586+
train_samples_list_text, val_plus_test_samples_list_text = train_test_split(
587587
non_instruct_samples,
588-
test_size=0.2,
588+
test_size=0.3,
589+
shuffle=True)
590+
591+
# Split the 30% hold-out into validation (70% of it, ~21% overall) and test (30% of it, ~9% overall) sets
592+
593+
val_samples_list_text, test_samples_list_text = train_test_split(
594+
val_plus_test_samples_list_text,
595+
test_size=0.3,
589596
shuffle=True)
590597

591598
del(non_instruct_samples)
@@ -600,6 +607,13 @@ def create_streaming_training_data(text_samples: List[str],
600607
)
601608

602609

610+
# Set up step 4 iterator for val set:
611+
x_val_packaged, y_val_packaged = create_streaming_training_data(
612+
text_samples=val_samples_list_text, # Validation subset of the text samples
613+
text_expansion_batch_size=2 # Expand 2 text samples at a time (~1GB memory)
614+
)
615+
616+
603617
# Set up step 4 iterator for test set:
604618
x_test_packaged, y_test_packaged = create_streaming_training_data(
605619
text_samples=test_samples_list_text, # Test subset of the text samples
@@ -876,7 +890,8 @@ def from_config(cls, config):
876890
output_shapes=OUTPUT_SHAPES,
877891
training_data=x_train_packaged,
878892
labels=y_train_packaged,
879-
validation_split=0.2,
893+
validation_split=0.0,
894+
validation_data=(x_val_packaged, y_val_packaged),
880895
direction='maximize',
881896
metric_to_rank_by="val_categorical_accuracy",
882897
minimum_levels=minimum_levels,

0 commit comments

Comments
 (0)