
Commit 94c5b82

Update generative-proof-of-concept-CPU-preprocessing-in-memory.py

Attempt to add sampling.

1 parent: a44878b


generative-proof-of-concept-CPU-preprocessing-in-memory.py

Lines changed: 29 additions & 51 deletions
@@ -789,62 +789,34 @@ def from_config(cls, config):
         config_obj = CerebrosNotGPTConfig.from_config(config['config'])
         return cls(config=config_obj)
 
-    def generate(self, token_ids, do_sample=False, max_new_tokens=None):
-        """
-        Generate text autoregressively from token IDs.
+    def generate(self, token_ids, do_sample=False, max_new_tokens=None, temperature=1.0, top_k=None, top_p=None):
+        # (init code as existing)
 
-        Args:
-            token_ids: Iterable of integers representing token IDs
-            do_sample: Boolean, if True use sampling, if False use greedy argmax
-            max_new_tokens: Maximum number of new tokens to generate
-
-        Returns:
-            List of token IDs including original tokens and generated tokens
-        """
-        # Convert token_ids to list if it's not already
-        if not isinstance(token_ids, list):
-            token_ids = list(token_ids)
-
-        # Determine the actual maximum number of new tokens
-        if max_new_tokens is None:
-            max_new_tokens = self.max_sequence_length - len(token_ids)
-        else:
-            max_new_tokens = min(max_new_tokens, self.max_sequence_length - len(token_ids))
-
-        # Initialize the generated tokens list
-        generated_tokens = []
-        current_tokens = token_ids.copy()
-
-        # Autoregressive generation loop
-        # temp_gen_count = 0 # <--------<< Debug code to remove later
         for _ in range(max_new_tokens):
-            # Pad or truncate to max_sequence_length (CORRECTED PADDING LOGIC)
-            if len(current_tokens) > self.max_sequence_length:
-                input_tokens = current_tokens[:self.max_sequence_length]
-            else:
-                # Manual padding with padding token
-                padding_needed = self.max_sequence_length - len(current_tokens)
-                input_tokens = current_tokens + [self.padding_token] * padding_needed
-
-            # Convert to tensor and get model prediction
+            # (padding code as existing)
             input_tensor = tf.constant([input_tokens], dtype=tf.int32)
-            logits = self.model(input_tensor)  # Shape: (batch_size, VOCABULARY_SIZE)
+            logits = self.model(input_tensor)
+
+            # Apply temperature scaling (logits->probs because your model returns softmax)
+            probs = logits[0]  # logits[0] is already softmax
 
-            # Get next token based on sampling strategy
             if do_sample:
-                # Sample from the distribution
-                # probabilities = tf.nn.softmax(logits[0], axis=-1)  # Model already applies softmax
-                next_token_id = tf.random.categorical(tf.math.log(logits[0])[None, :], 1)[0, 0].numpy()
+                # 1. Temperature: convert back to logits, scale, resoftmax
+                if temperature != 1.0:
+                    temp_logits = tf.math.log(probs + 1e-20) / temperature
+                    probs = tf.nn.softmax(temp_logits)
+                # 2. Top-k filtering
+                if top_k is not None and top_k > 0:
+                    probs = apply_top_k_probs(probs, top_k)
+                # 3. Top-p filtering
+                if top_p is not None and top_p < 1.0:
+                    probs = apply_top_p_probs(probs, top_p)
+                # Sample
+                next_token_id = tf.random.categorical(tf.math.log(probs[None, :]), 1)[0, 0].numpy()
             else:
                 # Greedy sampling (argmax)
-                next_token_id = int(tf.argmax(logits[0], axis=-1).numpy())
-                # Debug code to remove later
-                # print(f"Generating {temp_gen_count}")
-                # print(f"... next_token_id: {next_token_id}")
-                # next_word = tokenizer.decode(next_token_id)
-                # print(f"Next decoded word: {next_word}")
-                # temp_gen_count += 1
-
+                next_token_id = int(tf.argmax(probs, axis=-1).numpy())
+
             # Check for termination condition
             if next_token_id == self.padding_token:
                 break
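
Note: the new sampling branch calls two helpers, apply_top_k_probs and apply_top_p_probs, that this diff neither adds nor shows. If they are not already defined elsewhere in the file, helpers along the following lines would match the way they are called above (a 1-D probability vector in, a filtered and renormalized vector out). This is a sketch under that assumption, not the file's actual implementation:

import tensorflow as tf

def apply_top_k_probs(probs, k):
    # Zero out everything outside the k most probable tokens, then renormalize.
    k = min(int(k), int(probs.shape[-1]))
    threshold = tf.math.top_k(probs, k=k).values[-1]
    filtered = tf.where(probs >= threshold, probs, tf.zeros_like(probs))
    return filtered / tf.reduce_sum(filtered)

def apply_top_p_probs(probs, p):
    # Nucleus filtering: keep the smallest set of tokens whose cumulative
    # probability reaches p (always at least one token), then renormalize.
    order = tf.argsort(probs, direction='DESCENDING')
    sorted_probs = tf.gather(probs, order)
    cumulative = tf.cumsum(sorted_probs)
    keep_sorted = tf.cast(cumulative - sorted_probs < p, probs.dtype)
    mask = tf.scatter_nd(tf.expand_dims(order, 1), keep_sorted, tf.shape(probs))
    filtered = probs * mask
    return filtered / tf.reduce_sum(filtered)

With probs being the 1-D softmax vector taken from logits[0], apply_top_k_probs(probs, 20) and apply_top_p_probs(probs, 0.9) would each return a distribution ready for tf.random.categorical.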
@@ -864,6 +836,8 @@ def generate(self, token_ids, do_sample=False, max_new_tokens=None):
             total_tokens.extend([self.padding_token] * padding_needed)
 
         return total_tokens
+
+
 
     def call(self, inputs):
         # This is just for compatibility, the main logic is in generate()
@@ -929,9 +903,13 @@ def complete_text(text):
 
     # Now pass the list of integers to your generate method
     generated_tokens = generator.generate(
+        # do_sample=False, max_new_tokens=None, temperature=1.0, top_k=None, top_p=None
         token_ids=token_ids,  # Just the actual tokens, no padding
-        do_sample=False,
-        max_new_tokens=40
+        do_sample=True,
+        max_new_tokens=20,
+        temperature=0.6,
+        top_k=20,
+        top_p=0.9,
     )
 
     # Decode the result

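As context for the settings above: the temperature step in generate() takes the model's softmax output back into log space, divides by the temperature, and re-applies softmax, which is equivalent to raising each probability to the power 1/temperature and renormalizing. A temperature below 1.0, like the 0.6 passed here, sharpens the distribution toward the most likely tokens. A minimal illustration on a toy distribution (values invented for the example):

import tensorflow as tf

toy_probs = tf.constant([0.5, 0.3, 0.15, 0.05])  # stand-in for logits[0]
temperature = 0.6

# Same operation as in generate(): log, scale by 1/temperature, re-softmax.
sharpened = tf.nn.softmax(tf.math.log(toy_probs + 1e-20) / temperature)
print(sharpened.numpy())  # probability mass shifts toward the 0.5 token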