Commit db47c42

Update phishing_email_detection_gpt2.py
Syntax error ...
1 parent 8470094 commit db47c42

File tree

1 file changed

phishing_email_detection_gpt2.py

Lines changed: 5 additions & 67 deletions

@@ -655,76 +655,14 @@ def reset_state(self):
 print("GENERATED TEXT SAMPLES")
 print("="*50)
 
-# Get pad token id
-pad_token_id = tokenizer.pad_token_id
-end_prompt_token_id = tokenizer.encode("</prompt>", add_special_tokens=False)[0]
-
-# # Generate text for first 5 test samples (Working)
-# generated_texts = []
-# for i in range(min(5, len(x_test_packaged[0]))):
-#     original_input = x_test_packaged[0][i].numpy()
-
-#     # Find the end of the prompt
-#     try:
-#         end_prompt_index = list(original_input).index(end_prompt_token_id)
-#     except ValueError:
-#         end_prompt_index = 0
-
-#     # Extract the prompt part
-#     prompt_tokens = original_input[:end_prompt_index+1].tolist()
-
-#     # Generate tokens sequentially
-#     generated_tokens = []
-#     current_input = prompt_tokens.copy()
-
-#     # Generate up to 100 tokens or until pad token
-#     for _ in range(100):
-#         # Pad or truncate to MAX_SEQ_LENGTH
-#         input_tensor = tf.constant([current_input + [pad_token_id] * (MAX_SEQ_LENGTH - len(current_input))], dtype=tf.int32)
-
-#         # Get prediction
-#         prediction = reconstituted_model(input_tensor)
-#         next_token_id = int(tf.argmax(prediction[0], axis=-1).numpy())
-
-#         # Stop if pad token generated
-#         if next_token_id == pad_token_id:
-#             break
-
-#         generated_tokens.append(next_token_id)
-#         current_input.append(next_token_id)
-
-#         # Stop if we exceed max length
-#         if len(current_input) >= MAX_SEQ_LENGTH:
-#             break
-
-#     generated_texts.append((prompt_tokens, generated_tokens))
-
-# # Decode and print with proper formatting
-# for idx, (prompt_tokens, generated_tokens) in enumerate(generated_texts):
-#     # Decode prompt
-#     prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=False)
-
-#     # Extract original prompt content
-#     if '<prompt>' in prompt_text and '</prompt>' in prompt_text:
-#         original_prompt = prompt_text.split('<prompt>')[-1].split('</prompt>')[0]
-#     else:
-#         original_prompt = prompt_text[:50] + "..." if len(prompt_text) > 50 else prompt_text
-
-#     # Decode generated text
-#     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False) if generated_tokens else ""
-
-#     print(f"\nGenerated text from sample {idx+1}:")
-#     print(f"<prompt>{original_prompt}</prompt>{generated_text}")
-
-
 
 ## Proper model wrapper and generation method (under development):
 
 print("###### Output of the model wrapper (under development) ########### ")
 
 # Register the config and model wrapper as serializable
 @tf.keras.utils.register_keras_serializable()
-class CerebrosAutoregressiveTextGeneratorConfig:
+class CerebrosNotGPTConfig:
     def __init__(self, max_sequence_length=1536, padding_token=None):
         self.max_sequence_length = max_sequence_length
         self.padding_token = padding_token
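
Note on the deleted block above: it was dead, commented-out greedy decoding code, so removing it changes no behavior. For reference, a minimal sketch of the technique it implemented, assuming (as its use of tf.argmax(prediction[0]) implies) a model that returns next-token logits of shape [1, vocab_size] for a fixed-width [1, max_seq_length] input; greedy_generate is a hypothetical name, and model, prompt_tokens, and max_seq_length stand in for the script's reconstituted_model, a tokenized prompt, and MAX_SEQ_LENGTH:

import tensorflow as tf

def greedy_generate(model, prompt_tokens, pad_token_id, max_seq_length, max_new_tokens=100):
    """Greedy autoregressive decoding: take the argmax next token, append it,
    and repeat until the pad token is produced or a length limit is reached."""
    current_input = list(prompt_tokens)
    generated_tokens = []
    for _ in range(max_new_tokens):
        # Right-pad the running sequence to the fixed width the model expects.
        padding = [pad_token_id] * (max_seq_length - len(current_input))
        input_tensor = tf.constant([current_input + padding], dtype=tf.int32)
        prediction = model(input_tensor)  # assumed shape: [1, vocab_size]
        next_token_id = int(tf.argmax(prediction[0], axis=-1).numpy())
        if next_token_id == pad_token_id:
            break  # model signalled end of text
        generated_tokens.append(next_token_id)
        current_input.append(next_token_id)
        if len(current_input) >= max_seq_length:
            break  # no room left in the fixed-width input
    return generated_tokens
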
@@ -740,7 +678,7 @@ def from_config(cls, config):
         return cls(**config)
 
 @tf.keras.utils.register_keras_serializable()
-class CerebrosAutoregressiveTextGenerator(tf.keras.Model):
+class CerebrosNotGPT(tf.keras.Model):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.config = config
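
Note on the renames: both classes use the same Keras serialization pattern, and register_keras_serializable defaults the registered name to the class name, so artifacts saved under the old CerebrosAutoregressiveTextGenerator name will not deserialize as CerebrosNotGPT and need re-saving. A minimal sketch of the pattern, assuming only the two config fields visible in this diff (GeneratorConfig is a hypothetical stand-in for CerebrosNotGPTConfig, which carries more state):

import tensorflow as tf

@tf.keras.utils.register_keras_serializable()
class GeneratorConfig:
    # Hypothetical stand-in for CerebrosNotGPTConfig.
    def __init__(self, max_sequence_length=1536, padding_token=None):
        self.max_sequence_length = max_sequence_length
        self.padding_token = padding_token

    def get_config(self):
        # Keras stores these JSON-serializable values in the saved artifact,
        # keyed by the registered class name, and replays them on load.
        return {"max_sequence_length": self.max_sequence_length,
                "padding_token": self.padding_token}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
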
@@ -846,17 +784,17 @@ def call(self, inputs):
 print("="*50)
 
 # Create config and generator
-config = CerebrosAutoregressiveTextGeneratorConfig(
+config = CerebrosNotGPTConfig(
     max_sequence_length=MAX_SEQ_LENGTH,
     padding_token=tokenizer.pad_token_id
 )
-generator = CerebrosAutoregressiveTextGenerator(config)
+generator = CerebrosNotGPT(config)
 
 print("########### BEFORE SEARIALIZING THE GENERATIVE MODEL")
 
 def complete_text(text):
     input_ids = tokenizer(
-        half_sample,
+        text,
         add_special_tokens=False
     )['input_ids']
 
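
Note on the last hunk: the final change is the actual fix. complete_text tokenized a leftover global half_sample instead of its own text parameter, which raises a NameError at call time (a runtime error rather than a literal syntax error, despite the commit message). A sketch of the repaired call path; everything after the tokenizer call is not shown in this diff and is filled in hypothetically with the greedy_generate helper sketched above:

def complete_text(text):
    # Tokenize the argument itself, not a stale global.
    input_ids = tokenizer(text, add_special_tokens=False)['input_ids']
    # Hypothetical continuation: greedily extend the prompt and decode it.
    generated = greedy_generate(generator, input_ids,
                                pad_token_id=tokenizer.pad_token_id,
                                max_seq_length=MAX_SEQ_LENGTH)
    return tokenizer.decode(generated, skip_special_tokens=False)
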
