
Commit 0eeff75

Update phishing_email_detection_gpt2.py
Test some tweaks and corrections on generation examples ...
1 parent 3bc1800 commit 0eeff75

File tree

1 file changed: +34 -150 lines changed


phishing_email_detection_gpt2.py

Lines changed: 34 additions & 150 deletions
@@ -786,7 +786,7 @@ def generate(self, token_ids, do_sample=False, max_new_tokens=None):
         current_tokens = token_ids.copy()

         # Autoregressive generation loop
-        temp_gen_count = 0 # <--------<< Debug code to remove later
+        # temp_gen_count = 0 # <--------<< Debug code to remove later
         for _ in range(max_new_tokens):
             # Pad or truncate to max_sequence_length (CORRECTED PADDING LOGIC)
             if len(current_tokens) > self.max_sequence_length:
@@ -809,11 +809,11 @@ def generate(self, token_ids, do_sample=False, max_new_tokens=None):
             # Greedy sampling (argmax)
             next_token_id = int(tf.argmax(logits[0], axis=-1).numpy())
             # Debug code to remove later
-            print(f"Generating {temp_gen_count}")
-            print(f"... next_token_id: {next_token_id}")
-            next_word = tokenizer.decode(next_token_id)
-            print(f"Next decoded word: {next_word}")
-            temp_gen_count += 1
+            # print(f"Generating {temp_gen_count}")
+            # print(f"... next_token_id: {next_token_id}")
+            # next_word = tokenizer.decode(next_token_id)
+            # print(f"Next decoded word: {next_word}")
+            # temp_gen_count += 1

             # Check for termination condition
             if next_token_id == self.padding_token:
@@ -854,16 +854,38 @@ def call(self, inputs):

 print("########### BEFORE SERIALIZING THE GENERATIVE MODEL")

+def complete_text(text):
+    input_ids = tokenizer(
+        text,
+        add_special_tokens=False
+    )['input_ids']
+
+    generated_tokens = generator.generate(
+        token_ids=input_ids,  # Just the actual tokens, no padding
+        do_sample=False,
+        max_new_tokens=40
+    )
+    generated_text = \
+        tokenizer.decode(generated_tokens).replace(text, "")
+    return generated_text
+
+test_text = "I saw the sun and it was"
+response = complete_text(test_text)
+
+print(f"I ask the generator: {test_text}... It responds:")
+print(response)
+
 counter = 0
 for sample in non_instruct_samples:
-    half_sample_len = int(np.ceil(len(sample) / 2))
-    half_sample = sample[:half_sample_len]
+

     # Tokenize the text without padding first to get actual tokens
-    half_sample_tokenized = tokenizer(
-        half_sample,
+    sample_tokenized = tokenizer(
+        sample,
         add_special_tokens=False
     )['input_ids']
+    half_index = int(np.ceil(len(sample_tokenized) * 0.5))
+    half_sample_tokenized = sample_tokenized[:half_index]

     # Convert to Python list of integers
     if hasattr(half_sample_tokenized, 'numpy'):
@@ -883,62 +905,10 @@ def call(self, inputs):

     # Decode the result
     full_generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)
-    print(f"PROMPT number {counter}: {half_sample}; RESPONSE: {full_generated_text}")
+    print(f"PROMPT number {counter}: {tokenizer.decode(half_sample_tokenized)}; RESPONSE: {full_generated_text}")
     counter += 1


-
-
-
-# # Process ALL original samples from data - REAL WORLD USAGE
-# generated_texts = []
-# for i, original_text in enumerate(data[:5]): # Process first 5 samples
-#     print(f"\nProcessing sample {i+1}...")
-
-#     # Extract prompt part (everything up to and including </prompt>)
-#     if '</prompt>' in original_text:
-#         prompt_part = original_text.split('</prompt>')[0] + '</prompt>'
-#     else:
-#         prompt_part = original_text
-
-#     # Tokenize the prompt part
-#     tokenized = tokenizer(
-#         prompt_part,
-#         add_special_tokens=False, # We handle special tokens manually
-#         return_tensors=None # Return lists, not tensors
-#     )
-#     prompt_tokens = tokenized['input_ids']
-
-#     print(f"Original prompt: {prompt_part[:100]}...")
-#     print(f"Tokenized prompt length: {len(prompt_tokens)} tokens")
-
-#     # Generate tokens using the wrapper class - REAL WORLD USAGE
-#     generated_tokens = generator.generate(
-#         token_ids=prompt_tokens,
-#         do_sample=False,
-#         max_new_tokens=100
-#     )
-
-#     # Decode the full generated text
-#     full_generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)
-
-#     # Extract just the newly generated part (after the prompt)
-#     generated_part = full_generated_text[len(prompt_part):]
-
-#     generated_texts.append((prompt_part, generated_part))
-
-#     print(f"Generated response: {generated_part}...")
-
-# # Display results with proper formatting
-# print("\n" + "="*50)
-# print("FINAL GENERATED RESULTS")
-# print("="*50)
-
-# for idx, (original_prompt, generated_response) in enumerate(generated_texts):
-#     print(f"\nSample {idx+1}:")
-#     print(f"Prompt:{original_prompt}")
-#     print(f"Response: {generated_response}")
-
 # Save the model
 model_save_path = f"{TIME}_cerebros-autoregressive-text-generator.keras"
 generator.save(model_save_path)
@@ -988,96 +958,9 @@ def call(self, inputs):
     counter += 1


-# # Test with all original data samples - REAL WORLD DEMO (reconstituted)
-# print("\n" + "="*50)
-# print("GENERATED TEXT SAMPLES FROM ALL DATA - REAL WORLD USAGE (reconstituted)")
-# print("="*50)
-
-# generated_texts_all = []
-# for i, text in enumerate(data):
-#     # Extract prompt part (everything up to and including </prompt>)
-#     if '</prompt>' in text:
-#         prompt_text = text.split('</prompt>')[0] + '</prompt>'
-#     else:
-#         prompt_text = text
-
-#     # Tokenize the prompt part for model input
-#     tokenized = tokenizer(
-#         prompt_text,
-#         max_length=MAX_SEQ_LENGTH,
-#         padding='max_length',
-#         truncation=True,
-#         add_special_tokens=False
-#     )
-#     token_ids = tokenized['input_ids']
-
-#     # Generate using the reconstituted model
-#     generated_token_ids = reconstituted_generator.generate(
-#         token_ids=token_ids,
-#         do_sample=False,
-#         max_new_tokens=100
-#     )
-
-#     # Decode generated text
-#     generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=False)
-#     generated_texts_all.append(generated_text)
-
-
-#     print(f"\nSample {i+1}:")
-#     print(f"Prompt: {prompt_text}")
-#     print(f"Generated: {generated_text}")
-#     # [len(prompt_text):][:200]}...")
-
 print("\nAll samples processed with reconstituted model!")


-# # Test with all original data samples
-# print("\n" + "="*50)
-# print("GENERATED TEXT SAMPLES FROM ALL DATA")
-# print("="*50)
-
-# generated_texts_all = []
-# for i, text in enumerate(data[:3]): # Process first 3 for demo
-#     # Split such that everything before </prompt> or the entire text if </prompt> is not present
-#     if '</prompt>' in text:
-#         prompt_text = text.split('</prompt>')[0] + '</prompt>'
-#     else:
-#         prompt_text = text
-
-#     # Tokenize with proper padding
-#     tokenized = tokenizer(
-#         prompt_text,
-#         max_length=MAX_SEQ_LENGTH,
-#         padding='max_length',
-#         truncation=True,
-#         add_special_tokens=False
-#     )
-#     token_ids = tokenized['input_ids']
-
-#     # Generate using the reconstituted model
-#     generated_token_ids = reconstituted_generator.generate(
-#         token_ids=token_ids,
-#         do_sample=False,
-#         max_new_tokens=100
-#     )
-
-#     # Decode generated text
-#     generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=False)
-#     generated_texts_all.append(generated_text)
-
-#     # Extract and print prompt for display
-#     if '<prompt>' in text and '</prompt>' in text:
-#         display_prompt = text.split('<prompt>')[1].split('</prompt>')[0]
-#     else:
-#         display_prompt = text[:100] + "..." if len(text) > 100 else text
-
-#     print(f"\nSample {i+1}:")
-#     print(f"Prompt: {text}")
-#     print(f"Generated: {generated_text}") # [len(prompt_text):][:200]}...")
-
-print("\nAll samples processed!")
-
-

 ## Model validation
 print("Validation")
@@ -1086,6 +969,7 @@ def call(self, inputs):
     metrics=['accuracy']
 )

+
 results = reconstituted_model.evaluate(x_test_packaged, y_test_packaged)
 print("Test loss:", results[0])
 print("Test accuracy:", results[-1])

0 commit comments
