@@ -789,62 +789,34 @@ def from_config(cls, config):
         config_obj = CerebrosNotGPTConfig.from_config(config['config'])
         return cls(config=config_obj)
 
-    def generate(self, token_ids, do_sample=False, max_new_tokens=None):
-        """
-        Generate text autoregressively from token IDs.
+    def generate(self, token_ids, do_sample=False, max_new_tokens=None, temperature=1.0, top_k=None, top_p=None):
+        # (init code as existing)
 
-        Args:
-            token_ids: Iterable of integers representing token IDs
-            do_sample: Boolean, if True use sampling, if False use greedy argmax
-            max_new_tokens: Maximum number of new tokens to generate
-
-        Returns:
-            List of token IDs including original tokens and generated tokens
-        """
-        # Convert token_ids to list if it's not already
-        if not isinstance(token_ids, list):
-            token_ids = list(token_ids)
-
-        # Determine the actual maximum number of new tokens
-        if max_new_tokens is None:
-            max_new_tokens = self.max_sequence_length - len(token_ids)
-        else:
-            max_new_tokens = min(max_new_tokens, self.max_sequence_length - len(token_ids))
-
-        # Initialize the generated tokens list
-        generated_tokens = []
-        current_tokens = token_ids.copy()
-
-        # Autoregressive generation loop
-        # temp_gen_count = 0  # <--------<< Debug code to remove later
         for _ in range(max_new_tokens):
-            # Pad or truncate to max_sequence_length (CORRECTED PADDING LOGIC)
-            if len(current_tokens) > self.max_sequence_length:
-                input_tokens = current_tokens[:self.max_sequence_length]
-            else:
-                # Manual padding with padding token
-                padding_needed = self.max_sequence_length - len(current_tokens)
-                input_tokens = current_tokens + [self.padding_token] * padding_needed
-
-            # Convert to tensor and get model prediction
+            # (padding code as existing)
             input_tensor = tf.constant([input_tokens], dtype=tf.int32)
-            logits = self.model(input_tensor)  # Shape: (batch_size, VOCABULARY_SIZE)
+            logits = self.model(input_tensor)
+
+            # The model's output layer applies softmax, so these are already probabilities, not raw logits
+            probs = logits[0]
 
-            # Get next token based on sampling strategy
             if do_sample:
-                # Sample from the distribution
-                # probabilities = tf.nn.softmax(logits[0], axis=-1)  # Model already applies softmax
-                next_token_id = tf.random.categorical(tf.math.log(logits[0])[None, :], 1)[0, 0].numpy()
+                # 1. Temperature: convert probabilities back to logits, scale, re-apply softmax
+                if temperature != 1.0:
+                    temp_logits = tf.math.log(probs + 1e-20) / temperature
+                    probs = tf.nn.softmax(temp_logits)
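+                # Note: since probs = softmax(z) for the model's pre-softmax logits z,
+                # softmax(log(probs) / T) == softmax(z / T), so this recovers standard
+                # temperature scaling even though the model emits probabilities.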
+                # 2. Top-k filtering
+                if top_k is not None and top_k > 0:
+                    probs = apply_top_k_probs(probs, top_k)
+                # 3. Top-p (nucleus) filtering
+                if top_p is not None and top_p < 1.0:
+                    probs = apply_top_p_probs(probs, top_p)
+                # Sample (entries zeroed by the filters become -inf under log, so they are never drawn)
+                next_token_id = tf.random.categorical(tf.math.log(probs[None, :]), 1)[0, 0].numpy()
             else:
                 # Greedy sampling (argmax)
-                next_token_id = int(tf.argmax(logits[0], axis=-1).numpy())
-                # Debug code to remove later
-                # print(f"Generating {temp_gen_count}")
-                # print(f"... next_token_id: {next_token_id}")
-                # next_word = tokenizer.decode(next_token_id)
-                # print(f"Next decoded word: {next_word}")
-                # temp_gen_count += 1
-
+                next_token_id = int(tf.argmax(probs, axis=-1).numpy())
+
             # Check for termination condition
             if next_token_id == self.padding_token:
                 break
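
The sampling path above calls apply_top_k_probs and apply_top_p_probs, which are not defined in this hunk. A minimal sketch of what they could look like, assuming each takes a 1-D tensor of vocabulary probabilities and returns a renormalized distribution (the bodies below are an illustration, not the committed implementation):

import tensorflow as tf

def apply_top_k_probs(probs, k):
    # Keep only the k highest-probability tokens, zero out the rest, renormalize
    top_k_vals, _ = tf.math.top_k(probs, k=k)
    threshold = top_k_vals[-1]  # smallest probability still inside the top k
    filtered = tf.where(probs >= threshold, probs, tf.zeros_like(probs))
    return filtered / tf.reduce_sum(filtered)

def apply_top_p_probs(probs, p):
    # Nucleus filtering: keep the smallest set of tokens whose cumulative
    # probability reaches p, zero out the tail, renormalize
    sorted_idx = tf.argsort(probs, direction='DESCENDING')
    sorted_probs = tf.gather(probs, sorted_idx)
    cumulative = tf.cumsum(sorted_probs)
    # A token is kept if the mass accumulated before it is still below p;
    # this always keeps at least the single most probable token
    keep_sorted = (cumulative - sorted_probs) < p
    filtered_sorted = tf.where(keep_sorted, sorted_probs, tf.zeros_like(sorted_probs))
    # Scatter the filtered values back into vocabulary order
    filtered = tf.scatter_nd(sorted_idx[:, None], filtered_sorted, tf.shape(probs))
    return filtered / tf.reduce_sum(filtered)

Applied in that order (temperature, then top-k, then top-p), each step narrows the distribution before the categorical draw, matching how the flags are combined in the generate loop above.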
@@ -864,6 +836,8 @@ def generate(self, token_ids, do_sample=False, max_new_tokens=None):
             total_tokens.extend([self.padding_token] * padding_needed)
 
         return total_tokens
+
+
 
     def call(self, inputs):
         # This is just for compatibility; the main logic is in generate()
@@ -929,9 +903,13 @@ def complete_text(text):
 
     # Now pass the list of integers to your generate method
     generated_tokens = generator.generate(
+        # Defaults: do_sample=False, max_new_tokens=None, temperature=1.0, top_k=None, top_p=None
         token_ids=token_ids,  # Just the actual tokens, no padding
-        do_sample=False,
-        max_new_tokens=40
+        do_sample=True,
+        max_new_tokens=20,
+        temperature=0.6,
+        top_k=20,
+        top_p=0.9,
     )
 
     # Decode the result
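
The decode call itself sits below this hunk. For reference, it might look like the line below, assuming a tokenizer object with a decode method (the tokenizer name follows the debug comments removed above and is an assumption here, as is stripping padding via generator.padding_token):

    # Hypothetical decode step: drop padding tokens, then decode to text
    generated_text = tokenizer.decode([t for t in generated_tokens if t != generator.padding_token])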