
Commit fc6cac7

Update generative-proof-of-concept-CPU-preprocessing-in-memory.py
Remove CerebrosNotGPTConfig, CerebrosNotGPT from the main script ...
1 parent c378dbd commit fc6cac7

File tree

1 file changed: +7 -280 lines changed

generative-proof-of-concept-CPU-preprocessing-in-memory.py

Lines changed: 7 additions & 280 deletions
@@ -54,7 +54,11 @@ def objective(trial: optuna.Trial) -> float:
 import numpy as np
 from cerebros.simplecerebrosrandomsearch.simple_cerebros_random_search\
     import SimpleCerebrosRandomSearch
-from cerebrosllmutils.llm_utils import prepare_data, InterleavedRoPE
+from cerebrosllmutils.llm_utils import prepare_data, \
+    InterleavedRoPE, \
+    Perplexity, \
+    CerebrosNotGPTConfig, \
+    CerebrosNotGPT
 import pendulum
 from cerebros.units.units import DenseUnit
 from cerebros.denseautomlstructuralcomponent.dense_automl_structural_component\
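
Net effect of this hunk: the main script no longer defines Perplexity, CerebrosNotGPTConfig, or CerebrosNotGPT inline, and instead imports them from cerebrosllmutils.llm_utils alongside the existing prepare_data and InterleavedRoPE. A minimal import smoke test, assuming llm_utils now exports the same classes that the hunks below delete (the padding_token value is illustrative only):

# Hypothetical smoke test for the consolidated import, assuming
# cerebrosllmutils.llm_utils now exports the classes the hunks below delete.
from cerebrosllmutils.llm_utils import prepare_data, \
    InterleavedRoPE, \
    Perplexity, \
    CerebrosNotGPTConfig, \
    CerebrosNotGPT

config = CerebrosNotGPTConfig(max_sequence_length=1536, padding_token=0)  # padding_token=0 is illustrative
metric = Perplexity()
print(type(config).__name__, type(metric).__name__)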
@@ -583,37 +587,8 @@ def objective(trial: optuna.Trial) -> float:
 
 meta_trial_number = 42 # irrelevant unless in distributed training
 
-# Custom metric: Perplexity:
-
-@tf.keras.utils.register_keras_serializable()
-class Perplexity(tf.keras.metrics.Metric):
-    """
-    Computes perplexity, defined as e^(categorical crossentropy).
-    """
-    def __init__(self, name='perplexity', **kwargs):
-        super().__init__(name=name, **kwargs)
-        self.total_crossentropy = self.add_weight(name='total_crossentropy', initializer='zeros')
-        self.count = self.add_weight(name='count', initializer='zeros')
-
-    def update_state(self, y_true, y_pred, sample_weight=None):
-        # Calculate categorical crossentropy
-        crossentropy = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
-
-        # Update the running sum of crossentropy and the count of samples
-        self.total_crossentropy.assign_add(tf.reduce_sum(crossentropy))
-        self.count.assign_add(tf.cast(tf.shape(y_true)[0], dtype=tf.float32))
-
-    def result(self):
-        # Compute the average crossentropy
-        average_crossentropy = self.total_crossentropy / self.count
-        # Compute perplexity as e^(average crossentropy)
-        return tf.exp(average_crossentropy)
-
-    def reset_state(self):
-        # Reset the state variables
-        self.total_crossentropy.assign(0.0)
-        self.count.assign(0.0)
 
+# Custom metric: Perplexity:
 perplexity_metric = Perplexity()
 
 cerebros_automl = SimpleCerebrosRandomSearch(
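
The Perplexity metric deleted here defines perplexity as e raised to the average categorical crossentropy over the samples seen so far. A quick numeric sanity check of that definition in plain NumPy (the batch values are made up for the arithmetic):

import numpy as np

# Perplexity = exp(mean categorical crossentropy), per the deleted metric's
# docstring. Tiny hand-checkable batch over a 3-token vocabulary.
y_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1]])

# Categorical crossentropy per sample: -sum(y_true * log(y_pred))
ce = -np.sum(y_true * np.log(y_pred), axis=-1)  # [-log 0.7, -log 0.8]
print(np.exp(ce.mean()))  # exp((0.3567 + 0.2231) / 2) ~= 1.336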
@@ -709,257 +684,9 @@ def reset_state(self):
 print("="*50)
 
 
-# Register the config and model wrapper as serializable
-@tf.keras.utils.register_keras_serializable()
-class CerebrosNotGPTConfig:
-    def __init__(self, max_sequence_length=1536, padding_token=None):
-        self.max_sequence_length = max_sequence_length
-        self.padding_token = padding_token
-
-    def get_config(self):
-        return {
-            'max_sequence_length': self.max_sequence_length,
-            'padding_token': self.padding_token
-        }
-
-    @classmethod
-    def from_config(cls, config):
-        return cls(**config)
-
-@tf.keras.utils.register_keras_serializable()
-class CerebrosNotGPT(tf.keras.Model):
-    def __init__(self, config, **kwargs):
-        super().__init__(**kwargs)
-        self.config = config
-        self.max_sequence_length = config.max_sequence_length
-        self.padding_token = config.padding_token
-        # Make self.model = the reconstituted model (constant)
-        self.model = best_model_found # reconstituted_model
-
-    def get_config(self):
-        return {
-            'config': self.config.get_config()
-        }
-
-    @classmethod
-    def from_config(cls, config):
-        config_obj = CerebrosNotGPTConfig.from_config(config['config'])
-        return cls(config=config_obj)
-
-    @staticmethod
-    def apply_top_k_probs(probs, k):
-        if k is None or k <= 0:
-            return probs
-        # Flatten and argsort for indices
-        sorted_indices = tf.argsort(probs, direction='DESCENDING')
-        keep_indices = sorted_indices[:k]
-        mask = tf.zeros_like(probs, dtype=tf.bool)
-        mask = tf.tensor_scatter_nd_update(mask, tf.reshape(keep_indices, (-1,1)), tf.ones((k,), dtype=tf.bool))
-        filtered_probs = tf.where(mask, probs, tf.zeros_like(probs))
-        # Renormalize
-        filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)
-        return filtered_probs
-
-    @staticmethod
-    def apply_top_p_probs(probs, p):
-        if p is None or p >= 1.0:
-            return probs
-        sorted_indices = tf.argsort(probs, direction='DESCENDING')
-        sorted_probs = tf.gather(probs, sorted_indices)
-        cumulative_probs = tf.cumsum(sorted_probs)
-        mask = cumulative_probs <= p
-        # Always keep at least 1 token
-        mask = tf.concat([tf.constant([True]), mask[1:]], axis=0)
-        keep_indices = tf.boolean_mask(sorted_indices, mask)
-        filtered_probs = tf.where(tf.reduce_any(tf.equal(tf.range(tf.shape(probs)[0])[:,None], keep_indices), axis=1), probs, tf.zeros_like(probs))
-        # Renormalize
-        filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)
-        return filtered_probs
-
-
-    def generate(self,
-                 token_ids,
-                 do_sample=False,
-                 max_new_tokens=None,
-                 temperature=1.0,
-                 top_k=None,
-                 top_p=None,
-                 frequency_penalty=None,
-                 presence_penalty=None,
-                 repetition_penalty=None):
-        """
-        Generate text autoregressively from token IDs.
-        Applies filtering in sequence: penalties -> temperature -> top-k -> top-p
-        """
-        # Convert token_ids to list if it's not already
-        if not isinstance(token_ids, list):
-            token_ids = list(token_ids)
-
-        # Determine the actual maximum number of new tokens
-        if max_new_tokens is None:
-            max_new_tokens = self.max_sequence_length - len(token_ids)
-        else:
-            max_new_tokens = min(max_new_tokens, self.max_sequence_length - len(token_ids))
-
-        # Initialize the generated tokens list
-        generated_tokens = []
-        current_tokens = token_ids.copy()
-
-        # Autoregressive generation loop
-        for _ in range(max_new_tokens):
-            # Pad or truncate to max_sequence_length
-            if len(current_tokens) > self.max_sequence_length:
-                input_tokens = current_tokens[-self.max_sequence_length:]
-            else:
-                padding_needed = self.max_sequence_length - len(current_tokens)
-                input_tokens = current_tokens + [self.padding_token] * padding_needed
-
-            # Convert to tensor and get model prediction
-            input_tensor = tf.constant([input_tokens], dtype=tf.int32)
-            probs_nested = self.model(input_tensor)
-            probs = probs_nested[0] # Already softmax probabilities (NOT logits as comment says)
-            logits = tf.math.log(probs + 10 ** -20) # Convert to logits for penalty application
-
-            if do_sample:
-                # Apply repetition/frequency/presence penalties to logits
-                if frequency_penalty is not None or presence_penalty is not None:
-                    # Collect token counts from current_tokens
-                    token_counts = {}
-                    for t in current_tokens:
-                        token_counts[t] = token_counts.get(t, 0) + 1
-
-                    # Prepare penalty tensor
-                    vocab_size = tf.shape(logits)[0]
-                    penalties = tf.zeros_like(logits)
-
-                    for token_id, count in token_counts.items():
-                        if token_id >= vocab_size:
-                            continue
-                        penalty = 0.0
-                        if presence_penalty is not None:
-                            penalty += presence_penalty
-                        if frequency_penalty is not None:
-                            penalty += frequency_penalty * count
-
-                        penalties = tf.tensor_scatter_nd_add(
-                            penalties,
-                            [[token_id]],
-                            [penalty]
-                        )
-
-                    # Subtract penalties from logits
-                    logits = logits - penalties
-
-                # Apply repetition penalty (standard approach)
-                if repetition_penalty is not None and repetition_penalty != 1.0:
-                    # Collect unique tokens that have appeared
-                    unique_tokens = list(set(current_tokens))
-                    vocab_size = tf.shape(logits)[0]
-
-                    for token_id in unique_tokens:
-                        if token_id < vocab_size:
-                            # Divide logits of repeated tokens by penalty
-                            logits = tf.tensor_scatter_nd_update(
-                                logits,
-                                [[token_id]],
-                                [logits[token_id] / repetition_penalty]
-                            )
-
-                # Apply temperature
-                if temperature != 1.0:
-                    logits = logits / temperature
-
-                # Convert to probabilities
-                probs = tf.nn.softmax(logits)
-
-                # Apply top-k filtering (if specified)
-                if top_k is not None and top_k > 0:
-                    k = min(top_k, tf.shape(probs)[0])
-                    # Get top-k values and indices
-                    top_k_values, top_k_indices = tf.nn.top_k(probs, k=k, sorted=False)
-                    # Create mask for top-k positions
-                    top_k_mask = tf.scatter_nd(
-                        tf.expand_dims(top_k_indices, 1),
-                        tf.ones_like(top_k_values, dtype=tf.bool),
-                        tf.shape(probs)
-                    )
-                    # Zero out non-top-k probabilities
-                    probs = tf.where(top_k_mask, probs, tf.zeros_like(probs))
-                    # Renormalize
-                    probs = probs / tf.reduce_sum(probs)
-                    print(f">>> After top_k: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")
-
-                # Apply top-p filtering (if specified)
-                if top_p is not None and top_p < 1.0:
-                    # Sort probabilities in descending order
-                    sorted_indices = tf.argsort(probs, direction='DESCENDING')
-                    sorted_probs = tf.gather(probs, sorted_indices)
-                    cumulative_probs = tf.cumsum(sorted_probs)
-                    # Create mask for top-p
-                    mask = cumulative_probs <= top_p
-                    # Always keep at least one token
-                    mask = tf.concat([tf.constant([True]), mask[1:]], axis=0)
-                    # Get indices to keep
-                    keep_indices = tf.boolean_mask(sorted_indices, mask)
-                    # Create mask for original indices
-                    filter_mask = tf.scatter_nd(
-                        tf.expand_dims(keep_indices, 1),
-                        tf.ones_like(keep_indices, dtype=tf.bool),
-                        tf.shape(probs)
-                    )
-                    # Apply mask and renormalize
-                    probs = tf.where(filter_mask, probs, tf.zeros_like(probs))
-                    probs = probs / tf.reduce_sum(probs)
-                    print(f">>> After top_p: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")
-
-                # Sample from the final filtered distribution
-                # Get non-zero indices and their probabilities
-                non_zero_mask = probs > 1e-8
-                if tf.reduce_any(non_zero_mask):
-                    filtered_indices = tf.where(non_zero_mask)[:, 0] # Get indices
-                    filtered_probs = tf.boolean_mask(probs, non_zero_mask) # Get probabilities
-                    # Sample
-                    sampled_local_index = tf.random.categorical(tf.math.log(filtered_probs)[None, :], 1)[0, 0]
-                    # Map back to vocabulary index
-                    next_token_id = int(filtered_indices[sampled_local_index].numpy())
-                else:
-                    # Fallback if all probabilities are zero
-                    warn("Token sampling had to revert to greedy sampling, because no probs had a value > 0, unexpected")
-                    next_token_id = int(tf.argmax(probs, axis=-1).numpy())
-
-            else:
-                # Greedy sampling (argmax) - apply repetition penalty if needed
-                if repetition_penalty is not None and repetition_penalty != 1.0:
-                    unique_tokens = list(set(current_tokens))
-                    vocab_size = tf.shape(logits)[0]
-                    for token_id in unique_tokens:
-                        if token_id < vocab_size:
-                            logits = tf.tensor_scatter_nd_update(
-                                logits,
-                                [[token_id]],
-                                [logits[token_id] / repetition_penalty]
-                            )
-
-                next_token_id = int(tf.argmax(logits, axis=-1).numpy())
-
-            # Check for termination condition
-            if next_token_id == self.padding_token:
-                break
-
-            # Add to generated tokens and update current tokens
-            generated_tokens.append(int(next_token_id))
-            current_tokens.append(int(next_token_id))
-
-            # Check if we've reached max sequence length
-            if len(current_tokens) >= self.max_sequence_length:
-                break
-
-        return token_ids + generated_tokens
 
+
 
-    def call(self, inputs):
-        # This is just for compatibility, the main logic is in generate()
-        return self.model(inputs)
 
 # Replace the generation code block with this:
 
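For reference, the deleted generate() filters each step's next-token distribution in a fixed order, per its docstring: penalties -> temperature -> top-k -> top-p, then samples. A condensed single-step sketch of that same ordering in plain NumPy; the function name, defaults, and example values here are illustrative, not part of cerebrosllmutils:

import numpy as np

def filter_and_sample(probs, history, temperature=1.0, top_k=None,
                      top_p=None, presence_penalty=None,
                      frequency_penalty=None, rng=None):
    """One next-token step in the deleted generate() order:
    penalties -> temperature -> top-k -> top-p -> sample."""
    rng = rng or np.random.default_rng()
    logits = np.log(probs + 1e-20)  # model emits probabilities, not logits

    # Presence/frequency penalties: subtract from logits of seen tokens.
    if presence_penalty or frequency_penalty:
        ids, counts = np.unique(history, return_counts=True)
        logits[ids] -= (presence_penalty or 0.0) \
            + (frequency_penalty or 0.0) * counts

    # Temperature, then back to a normalized distribution (softmax).
    scaled = logits / temperature
    exp = np.exp(scaled - scaled.max())
    probs = exp / exp.sum()

    # Top-k: keep only the k most probable tokens, zero the rest.
    if top_k:
        cutoff = np.sort(probs)[-top_k]
        probs = np.where(probs >= cutoff, probs, 0.0)
        probs /= probs.sum()

    # Top-p (nucleus): smallest prefix of the sorted distribution with
    # cumulative mass <= p, always keeping at least one token.
    if top_p is not None and top_p < 1.0:
        order = np.argsort(probs)[::-1]
        keep = np.cumsum(probs[order]) <= top_p
        keep[0] = True
        mask = np.zeros_like(probs, dtype=bool)
        mask[order[keep]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()

    return int(rng.choice(len(probs), p=probs))

# Illustrative call: 5-token vocabulary, token 2 already generated twice.
step_probs = np.array([0.05, 0.1, 0.5, 0.3, 0.05])
next_id = filter_and_sample(step_probs, history=[2, 2, 3],
                            temperature=0.8, top_k=3, top_p=0.9,
                            frequency_penalty=0.5, presence_penalty=0.2)
print(next_id)

The deleted TensorFlow version does the same per loop iteration, with the wrinkle the inline comment calls out: the model outputs softmax probabilities rather than logits, so it first converts back via log(probs + 1e-20) before applying penalties and temperature.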