
Commit 4573702

Update phishing_email_detection_gpt2.py
1 parent 7a039dd commit 4573702

1 file changed

phishing_email_detection_gpt2.py

Lines changed: 47 additions & 10 deletions
@@ -190,26 +190,63 @@ def from_config(cls, config):
 ### Cerebros model:
 
 # TokenizerLayer class to handle tokenization and return only token_ids
-class TokenizerLayer(tf.keras.layers.Layer):
 
-    def __init__(self, max_seq_length, **kwargs):
-        super(TokenizerLayer, self).__init__(**kwargs)  # Update this line
-        self.tokenizer = GPT2Tokenizer.from_preset("gpt2_extra_large_en")
-        self.preprocessor = GPT2Preprocessor(self.tokenizer, sequence_length=max_seq_length)
+from transformers import AutoTokenizer
+import tensorflow as tf
+
+class NewTokenizerLayer(tf.keras.layers.Layer):
+    def __init__(self, max_seq_length, tokenizer_checkpoint, **kwargs):
+        super().__init__(**kwargs)
         self.max_seq_length = max_seq_length
+        self.tokenizer_checkpoint = tokenizer_checkpoint
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+        # Ensure tokenizer has a padding token
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
     def call(self, inputs):
-        prep = self.preprocessor([inputs])
-        return prep['token_ids']
+        def tokenize_py_fn(inputs):
+            # Convert TensorFlow bytes to Python strings
+            texts = [text.decode('utf-8') for text in inputs.numpy()]
+
+            # Tokenize with Hugging Face tokenizer
+            tokenized = self.tokenizer(
+                texts,
+                max_length=self.max_seq_length,
+                padding='max_length',
+                truncation=True,
+                return_tensors='tf'
+            )
+            return tokenized['input_ids'].numpy()
+
+        # Wrap Python function in TensorFlow operation
+        input_ids = tf.py_function(
+            tokenize_py_fn,
+            [inputs],
+            Tout=tf.int32
+        )
+
+        # Set shape for downstream layers
+        batch_size = tf.shape(inputs)[0]
+        input_ids.set_shape([None, self.max_seq_length])
+
+        return input_ids
 
     def get_config(self):
-        config = super(TokenizerLayer, self).get_config()
-        config.update({'max_seq_length': self.max_seq_length})
+        config = super().get_config()
+        config.update({
+            'max_seq_length': self.max_seq_length,
+            'tokenizer_checkpoint': self.tokenizer_checkpoint
+        })
         return config
 
     @classmethod
     def from_config(cls, config):
-        return cls(max_seq_length=config['max_seq_length'])
+        return cls(
+            max_seq_length=config['max_seq_length'],
+            tokenizer_checkpoint=config['tokenizer_checkpoint']
+        )
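For reference, a minimal sketch of how the updated layer might be used and serialized. The "gpt2" checkpoint name, the model wiring, the sample string, and the file path are illustrative assumptions, not part of this commit:

import tensorflow as tf

# Assumes NewTokenizerLayer (from the diff above) is defined in scope.
MAX_SEQ_LENGTH = 128  # hypothetical sequence length

# Functional model: raw strings in, token IDs out.
inputs = tf.keras.Input(shape=(), dtype=tf.string)
token_ids = NewTokenizerLayer(
    max_seq_length=MAX_SEQ_LENGTH,
    tokenizer_checkpoint="gpt2",  # hypothetical checkpoint choice
)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=token_ids)

# tf.py_function lets the Hugging Face tokenizer run on eager string tensors:
ids = model(tf.constant(["Your account has been suspended, click here."]))
print(ids.shape)  # (1, 128)

# Because get_config()/from_config() now round-trip tokenizer_checkpoint,
# the layer can be rebuilt at load time when passed as a custom object:
model.save("tokenizer_model.keras")
restored = tf.keras.models.load_model(
    "tokenizer_model.keras",
    custom_objects={"NewTokenizerLayer": NewTokenizerLayer},
)

Carrying tokenizer_checkpoint through get_config is what makes the save/load round trip work; the previous version hard-coded the "gpt2_extra_large_en" preset and would lose that information on reload.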