@@ -190,26 +190,63 @@ def from_config(cls, config):
 ### Cerebros model:
 
 # TokenizerLayer class to handle tokenization and return only token_ids
-class TokenizerLayer(tf.keras.layers.Layer):
 
-    def __init__(self, max_seq_length, **kwargs):
-        super(TokenizerLayer, self).__init__(**kwargs)  # Update this line
-        self.tokenizer = GPT2Tokenizer.from_preset("gpt2_extra_large_en")
-        self.preprocessor = GPT2Preprocessor(self.tokenizer, sequence_length=max_seq_length)
+from transformers import AutoTokenizer
+import tensorflow as tf
+
+class NewTokenizerLayer(tf.keras.layers.Layer):
+    def __init__(self, max_seq_length, tokenizer_checkpoint, **kwargs):
+        super().__init__(**kwargs)
         self.max_seq_length = max_seq_length
+        self.tokenizer_checkpoint = tokenizer_checkpoint
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+        # Ensure tokenizer has a padding token
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
 
     def call(self, inputs):
-        prep = self.preprocessor([inputs])
-        return prep['token_ids']
+        def tokenize_py_fn(inputs):
+            # Convert TensorFlow bytes to Python strings
+            texts = [text.decode('utf-8') for text in inputs.numpy()]
+
+            # Tokenize with Hugging Face tokenizer
+            tokenized = self.tokenizer(
+                texts,
+                max_length=self.max_seq_length,
+                padding='max_length',
+                truncation=True,
+                return_tensors='tf'
+            )
+            return tokenized['input_ids'].numpy()
+
+        # Wrap Python function in TensorFlow operation
+        input_ids = tf.py_function(
+            tokenize_py_fn,
+            [inputs],
+            Tout=tf.int32
+        )
+
+        # Set shape for downstream layers
+        batch_size = tf.shape(inputs)[0]
+        input_ids.set_shape([None, self.max_seq_length])
+
+        return input_ids
 
     def get_config(self):
-        config = super(TokenizerLayer, self).get_config()
-        config.update({'max_seq_length': self.max_seq_length})
+        config = super().get_config()
+        config.update({
+            'max_seq_length': self.max_seq_length,
+            'tokenizer_checkpoint': self.tokenizer_checkpoint
+        })
         return config
 
     @classmethod
     def from_config(cls, config):
-        return cls(max_seq_length=config['max_seq_length'])
+        return cls(
+            max_seq_length=config['max_seq_length'],
+            tokenizer_checkpoint=config['tokenizer_checkpoint']
+        )
 
 
 
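For reference, a minimal usage sketch of the new layer outside the diff. The "gpt2" checkpoint name, the 128-token sequence length, the toy Embedding/pooling/Dense head, and the "text_model.keras" save path are illustrative assumptions, not part of this commit; the actual Cerebros wiring lives elsewhere in the file.

import tensorflow as tf

max_seq_length = 128

# Text goes in as a 1-D batch of strings; the layer returns int32 token ids
# of shape (batch_size, max_seq_length).
text_input = tf.keras.Input(shape=(), dtype=tf.string, name="text")
token_ids = NewTokenizerLayer(
    max_seq_length=max_seq_length,
    tokenizer_checkpoint="gpt2",  # assumption: any Hugging Face checkpoint id works here
)(text_input)

# Hypothetical downstream head, only to make the sketch runnable end to end.
embedded = tf.keras.layers.Embedding(input_dim=50257, output_dim=64)(token_ids)
pooled = tf.keras.layers.GlobalAveragePooling1D()(embedded)
output = tf.keras.layers.Dense(1, activation="sigmoid")(pooled)
model = tf.keras.Model(inputs=text_input, outputs=output)
model.compile(optimizer="adam", loss="binary_crossentropy")

# Because get_config()/from_config() now carry tokenizer_checkpoint, the model
# can be saved and reloaded with the layer registered as a custom object.
model.save("text_model.keras")
restored = tf.keras.models.load_model(
    "text_model.keras",
    custom_objects={"NewTokenizerLayer": NewTokenizerLayer},
)

Storing tokenizer_checkpoint in the layer config is what makes this round trip possible: from_config can rebuild the layer and fetch the same tokenizer again, instead of relying on the previously hard-coded GPT-2 preset.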