Commit 4105067

Update llm_utils.py
Added a model- and environment-agnostic CerebrosNotGPT class to llm_utils.
1 parent fc6cac7 commit 4105067
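For reference, a minimal usage sketch (not part of this diff): it assumes `base_model` is a trained Keras model that maps a padded [1, max_sequence_length] int32 batch of token IDs to softmax probabilities over the vocabulary, that `tokenizer` exposes encode()/decode() methods, and that padding_token=0 is an arbitrary example choice.

    # Illustrative only - base_model, tokenizer, and padding_token=0 are assumptions,
    # and CerebrosNotGPTConfig / CerebrosNotGPT come from cerebrosllmutils/llm_utils.py.
    config = CerebrosNotGPTConfig(max_sequence_length=1536, padding_token=0)
    not_gpt = CerebrosNotGPT(config=config, model_0=base_model)

    prompt_ids = tokenizer.encode("Once upon a time")  # hypothetical tokenizer API
    output_ids = not_gpt.generate(
        prompt_ids,
        do_sample=True,
        max_new_tokens=64,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.1,
    )
    print(tokenizer.decode(output_ids))  # generate() returns prompt + generated token IDs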

1 file changed: +263 -0 lines changed


cerebrosllmutils/llm_utils.py

Lines changed: 263 additions & 0 deletions
@@ -310,3 +310,266 @@ def reset_state(self):
        self.total_crossentropy.assign(0.0)
        self.count.assign(0.0)

@tf.keras.utils.register_keras_serializable()
class CerebrosNotGPTConfig:
    def __init__(self, max_sequence_length=1536, padding_token=None):
        self.max_sequence_length = max_sequence_length
        self.padding_token = padding_token

    def get_config(self):
        return {
            'max_sequence_length': self.max_sequence_length,
            'padding_token': self.padding_token
            # NO model_0 here!
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)  # No model_0 to handle

@tf.keras.utils.register_keras_serializable()
class CerebrosNotGPT(tf.keras.Model):
    def __init__(self, config, model_0=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.max_sequence_length = config.max_sequence_length
        self.padding_token = config.padding_token

        # Handle model assignment
        if model_0 is not None:
            self.model = model_0
        else:
            # This branch is for deserialization - Keras will restore self.model automatically
            # if it was a proper Keras layer/model that was added via self.model = some_keras_model
            pass

    def get_config(self):
        return {
            'config': self.config.get_config()
            # NO model reference here!
        }

    @classmethod
    def from_config(cls, config):
        config_obj = CerebrosNotGPTConfig.from_config(config['config'])
        return cls(config=config_obj)  # Keras will handle model restoration

    def call(self, inputs):
        return self.model(inputs)

    @staticmethod
    def apply_top_k_probs(probs, k):
        if k is None or k <= 0:
            return probs
        # Flatten and argsort for indices
        sorted_indices = tf.argsort(probs, direction='DESCENDING')
        keep_indices = sorted_indices[:k]
        mask = tf.zeros_like(probs, dtype=tf.bool)
        mask = tf.tensor_scatter_nd_update(mask, tf.reshape(keep_indices, (-1, 1)),
                                           tf.ones((k,), dtype=tf.bool))
        filtered_probs = tf.where(mask, probs, tf.zeros_like(probs))
        # Renormalize
        filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)
        return filtered_probs

    @staticmethod
    def apply_top_p_probs(probs, p):
        if p is None or p >= 1.0:
            return probs
        sorted_indices = tf.argsort(probs, direction='DESCENDING')
        sorted_probs = tf.gather(probs, sorted_indices)
        cumulative_probs = tf.cumsum(sorted_probs)
        mask = cumulative_probs <= p
        # Always keep at least 1 token
        mask = tf.concat([tf.constant([True]), mask[1:]], axis=0)
        keep_indices = tf.boolean_mask(sorted_indices, mask)
        filtered_probs = tf.where(
            tf.reduce_any(tf.equal(tf.range(tf.shape(probs)[0])[:, None], keep_indices), axis=1), probs,
            tf.zeros_like(probs))
        # Renormalize
        filtered_probs = filtered_probs / tf.reduce_sum(filtered_probs)
        return filtered_probs

    def generate(self,
                 token_ids,
                 do_sample=False,
                 max_new_tokens=None,
                 temperature=1.0,
                 top_k=None,
                 top_p=None,
                 frequency_penalty=None,
                 presence_penalty=None,
                 repetition_penalty=None):
        """
        Generate text autoregressively from token IDs.
        Applies filtering in sequence: penalties -> temperature -> top-k -> top-p.
        """
        # Convert token_ids to a list if it is not one already
        if not isinstance(token_ids, list):
            token_ids = list(token_ids)

        # Determine the actual maximum number of new tokens
        if max_new_tokens is None:
            max_new_tokens = self.max_sequence_length - len(token_ids)
        else:
            max_new_tokens = min(max_new_tokens, self.max_sequence_length - len(token_ids))

        # Initialize the generated tokens list
        generated_tokens = []
        current_tokens = token_ids.copy()

        # Autoregressive generation loop
        for _ in range(max_new_tokens):
            # Pad or truncate to max_sequence_length
            if len(current_tokens) > self.max_sequence_length:
                input_tokens = current_tokens[-self.max_sequence_length:]
            else:
                padding_needed = self.max_sequence_length - len(current_tokens)
                input_tokens = current_tokens + [self.padding_token] * padding_needed

            # Convert to tensor and get the model's prediction
            input_tensor = tf.constant([input_tokens], dtype=tf.int32)
            probs_nested = self.model(input_tensor)
            probs = probs_nested[0]  # The model already outputs softmax probabilities, not logits
            logits = tf.math.log(probs + 1e-20)  # Convert back to log space for penalty application

            if do_sample:
                # Apply frequency/presence penalties to logits
                if frequency_penalty is not None or presence_penalty is not None:
                    # Collect token counts from current_tokens
                    token_counts = {}
                    for t in current_tokens:
                        token_counts[t] = token_counts.get(t, 0) + 1

                    # Prepare penalty tensor
                    vocab_size = tf.shape(logits)[0]
                    penalties = tf.zeros_like(logits)

                    for token_id, count in token_counts.items():
                        if token_id >= vocab_size:
                            continue
                        penalty = 0.0
                        if presence_penalty is not None:
                            penalty += presence_penalty
                        if frequency_penalty is not None:
                            penalty += frequency_penalty * count

                        penalties = tf.tensor_scatter_nd_add(
                            penalties,
                            [[token_id]],
                            [penalty]
                        )

                    # Subtract penalties from logits
                    logits = logits - penalties

                # Apply repetition penalty (standard approach)
                if repetition_penalty is not None and repetition_penalty != 1.0:
                    # Collect unique tokens that have appeared
                    unique_tokens = list(set(current_tokens))
                    vocab_size = tf.shape(logits)[0]

                    for token_id in unique_tokens:
                        if token_id < vocab_size:
                            # Divide logits of repeated tokens by the penalty
                            logits = tf.tensor_scatter_nd_update(
                                logits,
                                [[token_id]],
                                [logits[token_id] / repetition_penalty]
                            )

                # Apply temperature
                if temperature != 1.0:
                    logits = logits / temperature

                # Convert back to probabilities
                probs = tf.nn.softmax(logits)

                # Apply top-k filtering (if specified)
                if top_k is not None and top_k > 0:
                    k = min(top_k, tf.shape(probs)[0])
                    # Get top-k values and indices
                    top_k_values, top_k_indices = tf.nn.top_k(probs, k=k, sorted=False)
                    # Create mask for top-k positions
                    top_k_mask = tf.scatter_nd(
                        tf.expand_dims(top_k_indices, 1),
                        tf.ones_like(top_k_values, dtype=tf.bool),
                        tf.shape(probs)
                    )
                    # Zero out non-top-k probabilities
                    probs = tf.where(top_k_mask, probs, tf.zeros_like(probs))
                    # Renormalize
                    probs = probs / tf.reduce_sum(probs)
                    print(
                        f">>> After top_k: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")

                # Apply top-p filtering (if specified)
                if top_p is not None and top_p < 1.0:
                    # Sort probabilities in descending order
                    sorted_indices = tf.argsort(probs, direction='DESCENDING')
                    sorted_probs = tf.gather(probs, sorted_indices)
                    cumulative_probs = tf.cumsum(sorted_probs)
                    # Create mask for top-p
                    mask = cumulative_probs <= top_p
                    # Always keep at least one token
                    mask = tf.concat([tf.constant([True]), mask[1:]], axis=0)
                    # Get indices to keep
                    keep_indices = tf.boolean_mask(sorted_indices, mask)
                    # Create mask over the original indices
                    filter_mask = tf.scatter_nd(
                        tf.expand_dims(keep_indices, 1),
                        tf.ones_like(keep_indices, dtype=tf.bool),
                        tf.shape(probs)
                    )
                    # Apply mask and renormalize
                    probs = tf.where(filter_mask, probs, tf.zeros_like(probs))
                    probs = probs / tf.reduce_sum(probs)
                    print(
                        f">>> After top_p: {tf.shape(probs)} shape, {tf.reduce_sum(tf.cast(probs > 1e-8, tf.int32))} non-zero probs")

                # Sample from the final filtered distribution
                # Get non-zero indices and their probabilities
                non_zero_mask = probs > 1e-8
                if tf.reduce_any(non_zero_mask):
                    filtered_indices = tf.where(non_zero_mask)[:, 0]  # Indices into the vocabulary
                    filtered_probs = tf.boolean_mask(probs, non_zero_mask)  # Their probabilities
                    # Sample
                    sampled_local_index = tf.random.categorical(tf.math.log(filtered_probs)[None, :], 1)[0, 0]
                    # Map back to the vocabulary index
                    next_token_id = int(filtered_indices[sampled_local_index].numpy())
                else:
                    # Fallback if all probabilities are zero
                    warn(
                        "Token sampling reverted to greedy decoding because no probabilities were > 0; this is unexpected")
                    next_token_id = int(tf.argmax(probs, axis=-1).numpy())

            else:
                # Greedy decoding (argmax) - apply repetition penalty if requested
                if repetition_penalty is not None and repetition_penalty != 1.0:
                    unique_tokens = list(set(current_tokens))
                    vocab_size = tf.shape(logits)[0]
                    for token_id in unique_tokens:
                        if token_id < vocab_size:
                            logits = tf.tensor_scatter_nd_update(
                                logits,
                                [[token_id]],
                                [logits[token_id] / repetition_penalty]
                            )

                next_token_id = int(tf.argmax(logits, axis=-1).numpy())

            # Check for termination condition
            if next_token_id == self.padding_token:
                break

            # Add to generated tokens and update current tokens
            generated_tokens.append(int(next_token_id))
            current_tokens.append(int(next_token_id))

            # Check if we've reached max sequence length
            if len(current_tokens) >= self.max_sequence_length:
                break

        return token_ids + generated_tokens
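Because get_config()/from_config() intentionally omit model_0, a save/load round trip relies on Keras restoring the wrapped model as a tracked sub-layer, as the in-code comments describe. A minimal sketch of that round trip, continuing from the usage sketch above (file name and behavior on reload are assumptions, not verified by this commit):

    # Sketch only - assumes `import tensorflow as tf` and that both classes are
    # registered via @tf.keras.utils.register_keras_serializable() as in this commit.
    not_gpt.save("cerebros_not_gpt.keras")  # illustrative path
    restored = tf.keras.models.load_model("cerebros_not_gpt.keras")
    # from_config() rebuilds only the config; the tracked self.model sub-layer
    # is expected to be restored by Keras from the saved file.
    restored.generate(prompt_ids, do_sample=False, max_new_tokens=16)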
