@@ -310,3 +310,266 @@ def reset_state(self):
310310 self .total_crossentropy .assign (0.0 )
311311 self .count .assign (0.0 )
312312
@tf.keras.utils.register_keras_serializable()
class CerebrosNotGPTConfig:
    """Plain-value settings object consumed by ``CerebrosNotGPT``.

    Only the two constructor arguments are stored and serialized; no model
    reference is kept here, so ``get_config`` / ``from_config`` round-trip
    through an ordinary dict.
    """

    def __init__(self, max_sequence_length=1536, padding_token=None):
        """Store the sequence length and padding token exactly as given."""
        self.padding_token = padding_token
        self.max_sequence_length = max_sequence_length

    def get_config(self):
        """Return the constructor arguments as a dict (no model entry)."""
        return dict(
            max_sequence_length=self.max_sequence_length,
            padding_token=self.padding_token,
        )

    @classmethod
    def from_config(cls, config):
        """Rebuild an instance from a dict shaped like ``get_config()``'s."""
        restored = cls(**config)
        return restored
329+
330+
@tf.keras.utils.register_keras_serializable()
class CerebrosNotGPT(tf.keras.Model):
    """Keras model wrapper that adds autoregressive token generation.

    The wrapped model (``model_0``) is called on a fixed-length, padded
    window of token ids. ``generate`` treats ``model(input)[0]`` as a 1-D
    vector of next-token probabilities.
    NOTE(review): this assumes the wrapped model returns softmax
    probabilities of shape (batch, vocab) -- confirm against the model
    that is actually passed in.
    """

    def __init__(self, config, model_0=None, **kwargs):
        """Store the config values and, when given, the wrapped model.

        Args:
            config: object exposing ``max_sequence_length``,
                ``padding_token`` and ``get_config()``.
            model_0: the callable Keras model that produces next-token
                probabilities, or ``None``.
            **kwargs: forwarded to ``tf.keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        self.config = config
        self.max_sequence_length = config.max_sequence_length
        self.padding_token = config.padding_token

        if model_0 is not None:
            self.model = model_0
        # NOTE(review): when model_0 is None (the from_config path) the
        # attribute self.model is left unset, so call()/generate() fail
        # until something assigns it. The original author expected Keras
        # to restore it while loading -- confirm that this really happens.

    def get_config(self):
        """Return only the nested config dict; the wrapped model is omitted."""
        return {
            'config': self.config.get_config()
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild the wrapper from ``get_config()`` output, without a model."""
        config_obj = CerebrosNotGPTConfig.from_config(config['config'])
        return cls(config=config_obj)

    def call(self, inputs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(inputs)

    @staticmethod
    def apply_top_k_probs(probs, k):
        """Keep the ``k`` most probable entries of a 1-D vector, renormalized.

        Args:
            probs: 1-D tensor of probabilities.
            k: number of entries to keep; ``None`` or ``<= 0`` disables
                the filter. Values larger than the vector are clamped.

        Returns:
            Tensor shaped like ``probs`` with all other entries set to 0.
        """
        if k is None or k <= 0:
            return probs
        # Clamp so that an oversized k keeps everything instead of failing.
        k = min(int(k), int(tf.shape(probs)[0]))
        sorted_indices = tf.argsort(probs, direction='DESCENDING')
        # argsort of a permutation is its inverse: ranks[i] is the position
        # of entry i in the descending order.
        ranks = tf.argsort(sorted_indices)
        keep = ranks < k
        filtered_probs = tf.where(keep, probs, tf.zeros_like(probs))
        return filtered_probs / tf.reduce_sum(filtered_probs)

    @staticmethod
    def apply_top_p_probs(probs, p):
        """Nucleus filter: keep the smallest top set whose mass reaches ``p``.

        Args:
            probs: 1-D tensor of probabilities.
            p: cumulative-probability threshold; ``None`` or ``>= 1.0``
                disables the filter.

        Returns:
            Tensor shaped like ``probs``, filtered and renormalized. At
            least the single most probable entry is always kept.
        """
        if p is None or p >= 1.0:
            return probs
        sorted_indices = tf.argsort(probs, direction='DESCENDING')
        sorted_probs = tf.gather(probs, sorted_indices)
        # Exclusive cumsum = mass of all strictly better-ranked entries.
        # An entry stays while that mass is still below p, so the entry
        # that crosses the threshold is included (inclusive `cumsum <= p`
        # would drop it and keep less than p of the mass).
        mass_before = tf.cumsum(sorted_probs, exclusive=True)
        keep_sorted = mass_before < p
        # Always keep at least one token, even for p <= 0.
        keep_sorted = tf.concat([tf.constant([True]), keep_sorted[1:]], axis=0)
        # Map the mask from sorted order back to vocabulary order.
        keep = tf.gather(keep_sorted, tf.argsort(sorted_indices))
        filtered_probs = tf.where(keep, probs, tf.zeros_like(probs))
        return filtered_probs / tf.reduce_sum(filtered_probs)

    @staticmethod
    def _apply_count_penalties(logits, seen_tokens, frequency_penalty, presence_penalty):
        """Subtract presence/frequency penalties for tokens already seen."""
        vocab_size = int(tf.shape(logits)[0])
        counts = {}
        for token in seen_tokens:
            token = int(token)
            # Ids outside the vocabulary cannot be penalized; skip them.
            if 0 <= token < vocab_size:
                counts[token] = counts.get(token, 0) + 1
        if not counts:
            return logits

        presence = 0.0 if presence_penalty is None else presence_penalty
        frequency = 0.0 if frequency_penalty is None else frequency_penalty
        token_ids = list(counts)
        amounts = [presence + frequency * counts[token] for token in token_ids]

        # One scatter for all tokens instead of one TF op per token.
        penalties = tf.tensor_scatter_nd_add(
            tf.zeros_like(logits),
            tf.constant(token_ids, dtype=tf.int32)[:, None],
            tf.constant(amounts, dtype=logits.dtype),
        )
        return logits - penalties

    @staticmethod
    def _apply_repetition_penalty(logits, seen_tokens, repetition_penalty):
        """Lower the score of every token that already appeared.

        The logits used here are log-probabilities (<= 0). Dividing a
        negative value by a penalty > 1 would raise it, i.e. make repeats
        more likely, so negative scores are multiplied and positive scores
        divided.
        """
        vocab_size = int(tf.shape(logits)[0])
        token_ids = sorted(
            {int(token) for token in seen_tokens if 0 <= int(token) < vocab_size}
        )
        if not token_ids:
            return logits

        indices = tf.constant(token_ids, dtype=tf.int32)
        current = tf.gather(logits, indices)
        penalized = tf.where(
            current < 0,
            current * repetition_penalty,
            current / repetition_penalty,
        )
        return tf.tensor_scatter_nd_update(logits, indices[:, None], penalized)

    @staticmethod
    def _sample_token(probs):
        """Draw one vocabulary index from ``probs``; argmax if all are ~0."""
        non_zero_mask = probs > 1e-8
        if tf.reduce_any(non_zero_mask):
            candidate_ids = tf.where(non_zero_mask)[:, 0]
            candidate_probs = tf.boolean_mask(probs, non_zero_mask)
            # categorical() expects (batch, classes) log-probabilities.
            local_index = tf.random.categorical(
                tf.math.log(candidate_probs)[None, :], 1)[0, 0]
            # Map the position among candidates back to a vocabulary id.
            return int(candidate_ids[local_index].numpy())

        warn(
            "Token sampling had to revert to greedy sampling, because no probs had a value > 0, unexpected")
        return int(tf.argmax(probs, axis=-1).numpy())

    def generate(self,
                 token_ids,
                 do_sample=False,
                 max_new_tokens=None,
                 temperature=1.0,
                 top_k=None,
                 top_p=None,
                 frequency_penalty=None,
                 presence_penalty=None,
                 repetition_penalty=None):
        """Generate tokens autoregressively from a prompt of token ids.

        When sampling, filters are applied in this order:
        penalties -> temperature -> top-k -> top-p. In greedy mode only the
        repetition penalty is applied before taking the argmax. Runs
        eagerly (it converts tensors with ``.numpy()``).

        Args:
            token_ids: iterable of prompt token ids.
            do_sample: sample from the filtered distribution instead of
                taking the argmax.
            max_new_tokens: cap on generated tokens; always limited so the
                total never exceeds ``max_sequence_length``.
            temperature: divisor for the logits (sampling only).
            top_k: keep only the k most probable tokens (sampling only).
            top_p: nucleus threshold (sampling only).
            frequency_penalty: per-occurrence penalty (sampling only).
            presence_penalty: flat penalty for seen tokens (sampling only).
            repetition_penalty: values > 1 discourage seen tokens.

        Returns:
            A list: the prompt followed by the generated token ids.
            Generation stops early when the padding token is produced.

        Raises:
            ValueError: if sampling with ``temperature <= 0``, or if
                tokens must be generated but no padding token is set.
        """
        if not isinstance(token_ids, list):
            token_ids = list(token_ids)

        if do_sample and temperature <= 0:
            raise ValueError("temperature must be > 0 when do_sample=True")

        # Never generate past the model's fixed context length.
        remaining = self.max_sequence_length - len(token_ids)
        if max_new_tokens is None:
            max_new_tokens = remaining
        else:
            max_new_tokens = min(max_new_tokens, remaining)

        # Every step pads the window, so a padding token is mandatory.
        if max_new_tokens > 0 and self.padding_token is None:
            raise ValueError(
                "padding_token must be set in the config to pad model inputs")

        use_repetition_penalty = (
            repetition_penalty is not None and repetition_penalty != 1.0)
        use_count_penalties = (
            frequency_penalty is not None or presence_penalty is not None)

        generated_tokens = []
        current_tokens = token_ids.copy()

        for _ in range(max_new_tokens):
            # Truncate to the last max_sequence_length ids, then right-pad.
            window = current_tokens[-self.max_sequence_length:]
            padding_needed = self.max_sequence_length - len(window)
            input_tokens = window + [self.padding_token] * padding_needed

            input_tensor = tf.constant([input_tokens], dtype=tf.int32)
            # The model output is treated as probabilities, not logits.
            probs = self.model(input_tensor)[0]
            # Work in log space; the epsilon keeps log() finite at 0.
            logits = tf.math.log(probs + 10 ** -20)

            if do_sample:
                if use_count_penalties:
                    logits = self._apply_count_penalties(
                        logits, current_tokens,
                        frequency_penalty, presence_penalty)
                if use_repetition_penalty:
                    logits = self._apply_repetition_penalty(
                        logits, current_tokens, repetition_penalty)
                if temperature != 1.0:
                    logits = logits / temperature

                probs = tf.nn.softmax(logits)

                if top_k is not None and top_k > 0:
                    probs = self.apply_top_k_probs(probs, top_k)
                    # Diagnostic output kept from the original code.
                    print(
                        f">>> After top_k: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")

                if top_p is not None and top_p < 1.0:
                    probs = self.apply_top_p_probs(probs, top_p)
                    # Diagnostic output kept from the original code.
                    print(
                        f">>> After top_p: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")

                next_token_id = self._sample_token(probs)
            else:
                if use_repetition_penalty:
                    logits = self._apply_repetition_penalty(
                        logits, current_tokens, repetition_penalty)
                next_token_id = int(tf.argmax(logits, axis=-1).numpy())

            # The padding token doubles as the stop signal.
            if next_token_id == self.padding_token:
                break

            generated_tokens.append(next_token_id)
            current_tokens.append(next_token_id)

            if len(current_tokens) >= self.max_sequence_length:
                break

        return token_ids + generated_tokens
575+
0 commit comments