@@ -54,7 +54,11 @@ def objective(trial: optuna.Trial) -> float:
 import numpy as np
 from cerebros.simplecerebrosrandomsearch.simple_cerebros_random_search \
     import SimpleCerebrosRandomSearch
-from cerebrosllmutils.llm_utils import prepare_data, InterleavedRoPE
+from cerebrosllmutils.llm_utils import prepare_data, \
+    InterleavedRoPE, \
+    Perplexity, \
+    CerebrosNotGPTConfig, \
+    CerebrosNotGPT
 import pendulum
 from cerebros.units.units import DenseUnit
 from cerebros.denseautomlstructuralcomponent.dense_automl_structural_component \
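The widened import assumes that `cerebrosllmutils.llm_utils` now exports every class this commit deletes from the trial script. A minimal smoke test of that assumption, using only names taken verbatim from the diff:

    from cerebrosllmutils.llm_utils import (
        prepare_data, InterleavedRoPE, Perplexity,
        CerebrosNotGPTConfig, CerebrosNotGPT,
    )
    # Should print the three relocated class names without raising ImportError
    print([c.__name__ for c in (Perplexity, CerebrosNotGPTConfig, CerebrosNotGPT)])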
@@ -583,37 +587,8 @@ def objective(trial: optuna.Trial) -> float:
 
 meta_trial_number = 42  # irrelevant unless in distributed training
 
-# Custom metric: Perplexity:
-
-@tf.keras.utils.register_keras_serializable()
-class Perplexity(tf.keras.metrics.Metric):
-    """
-    Computes perplexity, defined as e^(categorical crossentropy).
-    """
-    def __init__(self, name='perplexity', **kwargs):
-        super().__init__(name=name, **kwargs)
-        self.total_crossentropy = self.add_weight(name='total_crossentropy', initializer='zeros')
-        self.count = self.add_weight(name='count', initializer='zeros')
-
-    def update_state(self, y_true, y_pred, sample_weight=None):
-        # Calculate categorical crossentropy
-        crossentropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
-
-        # Update the running sum of crossentropy and the count of samples
-        self.total_crossentropy.assign_add(tf.reduce_sum(crossentropy))
-        self.count.assign_add(tf.cast(tf.shape(y_true)[0], dtype=tf.float32))
-
-    def result(self):
-        # Compute the average crossentropy
-        average_crossentropy = self.total_crossentropy / self.count
-        # Compute perplexity as e^(average crossentropy)
-        return tf.exp(average_crossentropy)
-
-    def reset_state(self):
-        # Reset the state variables
-        self.total_crossentropy.assign(0.0)
-        self.count.assign(0.0)
 
+# Custom metric: Perplexity:
 perplexity_metric = Perplexity()
 
 cerebros_automl = SimpleCerebrosRandomSearch(
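Because `Perplexity` is defined as e raised to the mean categorical crossentropy, uniform predictions over V classes must score a perplexity of exactly V (each sample contributes a crossentropy of ln V). A quick sanity check, assuming the relocated class keeps the interface shown in the removed block above:

    import tensorflow as tf
    V = 8
    y_true = tf.one_hot([0, 3, 5], depth=V)  # three one-hot targets
    y_pred = tf.fill((3, V), 1.0 / V)        # uniform predictions over V classes
    m = Perplexity()                         # as imported from cerebrosllmutils.llm_utils
    m.update_state(y_true, y_pred)
    print(float(m.result()))                 # expect 8.0 (= V)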
@@ -709,257 +684,9 @@ def reset_state(self):
 print("=" * 50)
 
 
-# Register the config and model wrapper as serializable
-@tf.keras.utils.register_keras_serializable()
-class CerebrosNotGPTConfig:
-    def __init__(self, max_sequence_length=1536, padding_token=None):
-        self.max_sequence_length = max_sequence_length
-        self.padding_token = padding_token
-
-    def get_config(self):
-        return {
-            'max_sequence_length': self.max_sequence_length,
-            'padding_token': self.padding_token
-        }
-
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
-@tf.keras.utils.register_keras_serializable()
-class CerebrosNotGPT(tf.keras.Model):
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
-        self.config = config
-        self.max_sequence_length = config.max_sequence_length
-        self.padding_token = config.padding_token
-        # Make self.model = the reconstituted model (constant)
-        self.model = best_model_found  # reconstituted_model
-
-    def get_config(self):
-        return {
-            'config': self.config.get_config()
-        }
-
-    @classmethod
-    def from_config(cls, config):
-        config_obj = CerebrosNotGPTConfig.from_config(config['config'])
-        return cls(config=config_obj)
-
-    @staticmethod
-    def apply_top_k_probs(probs, k):
-        if k is None or k <= 0:
-            return probs
-        # Flatten and argsort for indices
-        sorted_indices = tf.argsort(probs, direction='DESCENDING')
-        keep_indices = sorted_indices[:k]
-        mask = tf.zeros_like(probs, dtype=tf.bool)
-        mask = tf.tensor_scatter_nd_update(mask, tf.reshape(keep_indices, (-1, 1)), tf.ones((k,), dtype=tf.bool))
-        filtered_probs = tf.where(mask, probs, tf.zeros_like(probs))
-        # Renormalize
-        filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)
-        return filtered_probs
-
-    @staticmethod
-    def apply_top_p_probs(probs, p):
-        if p is None or p >= 1.0:
-            return probs
-        sorted_indices = tf.argsort(probs, direction='DESCENDING')
-        sorted_probs = tf.gather(probs, sorted_indices)
-        cumulative_probs = tf.cumsum(sorted_probs)
-        mask = cumulative_probs <= p
-        # Always keep at least 1 token
-        mask = tf.concat([tf.constant([True]), mask[1:]], axis=0)
-        keep_indices = tf.boolean_mask(sorted_indices, mask)
-        filtered_probs = tf.where(tf.reduce_any(tf.equal(tf.range(tf.shape(probs)[0])[:, None], keep_indices), axis=1), probs, tf.zeros_like(probs))
-        # Renormalize
-        filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)
-        return filtered_probs
-
-
-    def generate(self,
-                 token_ids,
-                 do_sample=False,
-                 max_new_tokens=None,
-                 temperature=1.0,
-                 top_k=None,
-                 top_p=None,
-                 frequency_penalty=None,
-                 presence_penalty=None,
-                 repetition_penalty=None):
-        """
-        Generate text autoregressively from token IDs.
-        Applies filtering in sequence: penalties -> temperature -> top-k -> top-p
-        """
-        # Convert token_ids to list if it's not already
-        if not isinstance(token_ids, list):
-            token_ids = list(token_ids)
-
-        # Determine the actual maximum number of new tokens
-        if max_new_tokens is None:
-            max_new_tokens = self.max_sequence_length - len(token_ids)
-        else:
-            max_new_tokens = min(max_new_tokens, self.max_sequence_length - len(token_ids))
-
-        # Initialize the generated tokens list
-        generated_tokens = []
-        current_tokens = token_ids.copy()
-
-        # Autoregressive generation loop
-        for _ in range(max_new_tokens):
-            # Pad or truncate to max_sequence_length
-            if len(current_tokens) > self.max_sequence_length:
-                input_tokens = current_tokens[-self.max_sequence_length:]
-            else:
-                padding_needed = self.max_sequence_length - len(current_tokens)
-                input_tokens = current_tokens + [self.padding_token] * padding_needed
-
-            # Convert to tensor and get model prediction
-            input_tensor = tf.constant([input_tokens], dtype=tf.int32)
-            probs_nested = self.model(input_tensor)
-            probs = probs_nested[0]  # Already softmax probabilities (NOT logits as comment says)
-            logits = tf.math.log(probs + 10 ** -20)  # Convert to logits for penalty application
-
-            if do_sample:
-                # Apply repetition/frequency/presence penalties to logits
-                if frequency_penalty is not None or presence_penalty is not None:
-                    # Collect token counts from current_tokens
-                    token_counts = {}
-                    for t in current_tokens:
-                        token_counts[t] = token_counts.get(t, 0) + 1
-
-                    # Prepare penalty tensor
-                    vocab_size = tf.shape(logits)[0]
-                    penalties = tf.zeros_like(logits)
-
-                    for token_id, count in token_counts.items():
-                        if token_id >= vocab_size:
-                            continue
-                        penalty = 0.0
-                        if presence_penalty is not None:
-                            penalty += presence_penalty
-                        if frequency_penalty is not None:
-                            penalty += frequency_penalty * count
-
-                        penalties = tf.tensor_scatter_nd_add(
-                            penalties,
-                            [[token_id]],
-                            [penalty]
-                        )
-
-                    # Subtract penalties from logits
-                    logits = logits - penalties
-
-                # Apply repetition penalty (standard approach)
-                if repetition_penalty is not None and repetition_penalty != 1.0:
-                    # Collect unique tokens that have appeared
-                    unique_tokens = list(set(current_tokens))
-                    vocab_size = tf.shape(logits)[0]
-
-                    for token_id in unique_tokens:
-                        if token_id < vocab_size:
-                            # Divide logits of repeated tokens by penalty
-                            logits = tf.tensor_scatter_nd_update(
-                                logits,
-                                [[token_id]],
-                                [logits[token_id] / repetition_penalty]
-                            )
-
-                # Apply temperature
-                if temperature != 1.0:
-                    logits = logits / temperature
-
-                # Convert to probabilities
-                probs = tf.nn.softmax(logits)
-
-                # Apply top-k filtering (if specified)
-                if top_k is not None and top_k > 0:
-                    k = min(top_k, tf.shape(probs)[0])
-                    # Get top-k values and indices
-                    top_k_values, top_k_indices = tf.nn.top_k(probs, k=k, sorted=False)
-                    # Create mask for top-k positions
-                    top_k_mask = tf.scatter_nd(
-                        tf.expand_dims(top_k_indices, 1),
-                        tf.ones_like(top_k_values, dtype=tf.bool),
-                        tf.shape(probs)
-                    )
-                    # Zero out non-top-k probabilities
-                    probs = tf.where(top_k_mask, probs, tf.zeros_like(probs))
-                    # Renormalize
-                    probs = probs / tf.reduce_sum(probs)
-                    print(f">>> After top_k: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")
-
-                # Apply top-p filtering (if specified)
-                if top_p is not None and top_p < 1.0:
-                    # Sort probabilities in descending order
-                    sorted_indices = tf.argsort(probs, direction='DESCENDING')
-                    sorted_probs = tf.gather(probs, sorted_indices)
-                    cumulative_probs = tf.cumsum(sorted_probs)
-                    # Create mask for top-p
-                    mask = cumulative_probs <= top_p
-                    # Always keep at least one token
-                    mask = tf.concat([tf.constant([True]), mask[1:]], axis=0)
-                    # Get indices to keep
-                    keep_indices = tf.boolean_mask(sorted_indices, mask)
-                    # Create mask for original indices
-                    filter_mask = tf.scatter_nd(
-                        tf.expand_dims(keep_indices, 1),
-                        tf.ones_like(keep_indices, dtype=tf.bool),
-                        tf.shape(probs)
-                    )
-                    # Apply mask and renormalize
-                    probs = tf.where(filter_mask, probs, tf.zeros_like(probs))
-                    probs = probs / tf.reduce_sum(probs)
-                    print(f">>> After top_p: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")
-
-                # Sample from the final filtered distribution
-                # Get non-zero indices and their probabilities
-                non_zero_mask = probs > 1e-8
-                if tf.reduce_any(non_zero_mask):
-                    filtered_indices = tf.where(non_zero_mask)[:, 0]  # Get indices
-                    filtered_probs = tf.boolean_mask(probs, non_zero_mask)  # Get probabilities
-                    # Sample
-                    sampled_local_index = tf.random.categorical(tf.math.log(filtered_probs)[None, :], 1)[0, 0]
-                    # Map back to vocabulary index
-                    next_token_id = int(filtered_indices[sampled_local_index].numpy())
-                else:
-                    # Fallback if all probabilities are zero
-                    warn("Token sampling had to revert to greedy sampling, because no probs had a value > 0, unexpected")
-                    next_token_id = int(tf.argmax(probs, axis=-1).numpy())
-
-            else:
-                # Greedy sampling (argmax) - apply repetition penalty if needed
-                if repetition_penalty is not None and repetition_penalty != 1.0:
-                    unique_tokens = list(set(current_tokens))
-                    vocab_size = tf.shape(logits)[0]
-                    for token_id in unique_tokens:
-                        if token_id < vocab_size:
-                            logits = tf.tensor_scatter_nd_update(
-                                logits,
-                                [[token_id]],
-                                [logits[token_id] / repetition_penalty]
-                            )
-
-                next_token_id = int(tf.argmax(logits, axis=-1).numpy())
-
-            # Check for termination condition
-            if next_token_id == self.padding_token:
-                break
-
-            # Add to generated tokens and update current tokens
-            generated_tokens.append(int(next_token_id))
-            current_tokens.append(int(next_token_id))
-
-            # Check if we've reached max sequence length
-            if len(current_tokens) >= self.max_sequence_length:
-                break
-
-        return token_ids + generated_tokens
 
+
 
-    def call(self, inputs):
-        # This is just for compatibility, the main logic is in generate()
-        return self.model(inputs)
 
 # Replace the generation code block with this:
 
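For reference, a hedged sketch of how the relocated wrapper would be exercised after this commit, based on the constructor and `generate()` signature visible in the removed code. The padding token value, the prompt IDs, and the in-scope `best_model_found` that `CerebrosNotGPT.__init__` captures are stand-ins for objects the surrounding script is expected to provide:

    config = CerebrosNotGPTConfig(max_sequence_length=1536, padding_token=0)  # 0 is illustrative
    not_gpt = CerebrosNotGPT(config=config)   # wraps the best model found by the search
    prompt_ids = [101, 2023, 2003]            # illustrative token IDs from the tokenizer
    output_ids = not_gpt.generate(
        prompt_ids,
        do_sample=True,        # sample instead of greedy argmax
        max_new_tokens=64,
        temperature=0.8,       # applied before top-k / top-p filtering
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.1,
    )
    print(output_ids)          # prompt IDs followed by generated IDs

Generation ends early if the padding token is sampled or the sequence reaches `max_sequence_length`.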