@@ -862,29 +862,29 @@ def call(self, inputs):
862862 # Tokenize the text
863863 half_sample_tokenized = tokenizer (
864864 half_sample ,
865- max_length = max_seq_length ,
865+ max_length = MAX_SEQ_LENGTH ,
866866 padding = 'max_length' ,
867867 truncation = True ,
868868 add_special_tokens = False
869869 )['input_ids' ]
870870
871- # Extract token IDs as a list of integers (not tensors)
872- if isinstance (half_sample_tokenized , dict ):
873- # If tokenizer returns a dict, extract the token IDs
874- token_ids = half_sample_tokenized ['input_ids' ] # or 'token_ids' depending on your tokenizer
875- else :
876- # If tokenizer returns a list directly
877- token_ids = half_sample_tokenized
871+ # # Extract token IDs as a list of integers (not tensors)
872+ # if isinstance(half_sample_tokenized, dict):
873+ # # If tokenizer returns a dict, extract the token IDs
874+ # token_ids = half_sample_tokenized['input_ids'] # or 'token_ids' depending on your tokenizer
875+ # else:
876+ # # If tokenizer returns a list directly
877+ # token_ids = half_sample_tokenized
878878
879- # Convert to Python list of integers if it's a tensor
880- if hasattr (token_ids , 'numpy' ):
881- token_ids = token_ids .numpy ().tolist ()
882- if not isinstance (token_ids , list ):
883- token_ids = list (token_ids )
879+ # # Convert to Python list of integers if it's a tensor
880+ # if hasattr(token_ids, 'numpy'):
881+ # token_ids = token_ids.numpy().tolist()
882+ # if not isinstance(token_ids, list):
883+ # token_ids = list(token_ids)
884884
885885 # Now pass the list of integers to your generate method
886- generated_tokens = generator .generate (
887- token_ids = token_ids , # This should now be a list of integers
886+ generated_tokens = reconstituted_generator .generate (
887+ token_ids = half_sample_tokenized , # This should now be a list of integers
888888 do_sample = False ,
889889 max_new_tokens = 40
890890 )
@@ -972,7 +972,7 @@ def call(self, inputs):
972972 # Tokenize the text
973973 half_sample_tokenized = tokenizer (
974974 half_sample ,
975- max_length = max_seq_length ,
975+ max_length = MAX_SEQ_LENGTH ,
976976 padding = 'max_length' ,
977977 truncation = True ,
978978 add_special_tokens = False
@@ -994,7 +994,7 @@ def call(self, inputs):
994994
995995 # Now pass the list of integers to your generate method
996996 generated_tokens = reconstituted_generator .generate (
997- token_ids = token_ids , # This should now be a list of integers
997+ token_ids = half_sample_tokenized , # This should now be a list of integers
998998 do_sample = False ,
999999 max_new_tokens = 40
10001000 )
0 commit comments