@@ -860,7 +860,13 @@ def call(self, inputs):
860860 half_sample = sample [:half_sample_len ]
861861
862862 # Tokenize the text
863- half_sample_tokenized = tokenizer (half_sample )
863+ half_sample_tokenized = tokenizer (
864+ half_sample ,
865+ max_length = max_seq_length ,
866+ padding = 'max_length' ,
867+ truncation = True ,
868+ add_special_tokens = False
869+ )['input_ids' ]
864870
865871 # Extract token IDs as a list of integers (not tensors)
866872 if isinstance (half_sample_tokenized , dict ):
@@ -964,21 +970,27 @@ def call(self, inputs):
964970 half_sample = sample [:half_sample_len ]
965971
966972 # Tokenize the text
967- half_sample_tokenized = tokenizer (half_sample )
973+ half_sample_tokenized = tokenizer (
974+ half_sample ,
975+ max_length = max_seq_length ,
976+ padding = 'max_length' ,
977+ truncation = True ,
978+ add_special_tokens = False
979+ )['input_ids' ]
968980
969- # Extract token IDs as a list of integers (not tensors)
970- if isinstance (half_sample_tokenized , dict ):
971- # If tokenizer returns a dict, extract the token IDs
972- token_ids = half_sample_tokenized ['input_ids' ] # or 'token_ids' depending on your tokenizer
973- else :
974- # If tokenizer returns a list directly
975- token_ids = half_sample_tokenized
981+ # # Extract token IDs as a list of integers (not tensors)
982+ # if isinstance(half_sample_tokenized, dict):
983+ # # If tokenizer returns a dict, extract the token IDs
984+ # token_ids = half_sample_tokenized['input_ids'] # or 'token_ids' depending on your tokenizer
985+ # else:
986+ # # If tokenizer returns a list directly
987+ # token_ids = half_sample_tokenized
976988
977- # Convert to Python list of integers if it's a tensor
978- if hasattr (token_ids , 'numpy' ):
979- token_ids = token_ids .numpy ().tolist ()
980- if not isinstance (token_ids , list ):
981- token_ids = list (token_ids )
989+ # # Convert to Python list of integers if it's a tensor
990+ # if hasattr(token_ids, 'numpy'):
991+ # token_ids = token_ids.numpy().tolist()
992+ # if not isinstance(token_ids, list):
993+ # token_ids = list(token_ids)
982994
983995 # Now pass the list of integers to your generate method
984996 generated_tokens = reconstituted_generator .generate (
0 commit comments