Skip to content

Commit 44946da

Browse files
Update phishing_email_detection_gpt2.py
Add a test that we can load the model and run data through it. May need dummy data reformatted.
1 parent 7d4a951 commit 44946da

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

phishing_email_detection_gpt2.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,3 +545,40 @@ def from_config(cls, config):
545545
# Save the model with custom objects
546546
gpt_baseline_model.save('gpt_baseline_model.h5', save_format='h5', custom_objects=custom_objects)
547547
cerebros_base_model.save('cerebros_base_model.h5', save_format='h5', custom_objects=custom_objects)
548+
549+
# Test loading the models back
550+
print("Testing model loading...")
551+
try:
552+
# Load GPT baseline model
553+
loaded_gpt_model = tf.keras.models.load_model('gpt_baseline_model.h5', custom_objects=custom_objects)
554+
print("✓ GPT baseline model loaded successfully!")
555+
556+
# Verify GPT model structure
557+
print("GPT Model Summary:")
558+
print(loaded_gpt_model.summary())
559+
560+
# Test GPT model prediction
561+
test_input = tf.constant(["This is a test email for phishing detection."])
562+
gpt_prediction = loaded_gpt_model.predict(test_input)
563+
print(f"GPT Model prediction shape: {gpt_prediction.shape}")
564+
print(f"GPT Model prediction sample: {gpt_prediction[0]}")
565+
566+
# Load Cerebros base model
567+
loaded_cerebros_model = tf.keras.models.load_model('cerebros_base_model.h5', custom_objects=custom_objects)
568+
print("✓ Cerebros base model loaded successfully!")
569+
570+
# Verify Cerebros model structure
571+
print("Cerebros Model Summary:")
572+
print(loaded_cerebros_model.summary())
573+
574+
# Test Cerebros model prediction
575+
cerebros_prediction = loaded_cerebros_model.predict(test_input)
576+
print(f"Cerebros Model prediction shape: {cerebros_prediction.shape}")
577+
print(f"Cerebros Model prediction sample shape: {cerebros_prediction[0].shape}")
578+
579+
print("✓ All models loaded and validated successfully!")
580+
581+
except Exception as e:
582+
print(f"✗ Error loading models: {e}")
583+
raise
584+

0 commit comments

Comments
 (0)