@@ -655,76 +655,14 @@ def reset_state(self):
 print("GENERATED TEXT SAMPLES")
 print("=" * 50)
 
-# Get pad token id
-pad_token_id = tokenizer.pad_token_id
-end_prompt_token_id = tokenizer.encode("</prompt>", add_special_tokens=False)[0]
-
-# # Generate text for first 5 test samples (Working)
-# generated_texts = []
-# for i in range(min(5, len(x_test_packaged[0]))):
-#     original_input = x_test_packaged[0][i].numpy()
-
-#     # Find the end of the prompt
-#     try:
-#         end_prompt_index = list(original_input).index(end_prompt_token_id)
-#     except ValueError:
-#         end_prompt_index = 0
-
-#     # Extract the prompt part
-#     prompt_tokens = original_input[:end_prompt_index+1].tolist()
-
-#     # Generate tokens sequentially
-#     generated_tokens = []
-#     current_input = prompt_tokens.copy()
-
-#     # Generate up to 100 tokens or until pad token
-#     for _ in range(100):
-#         # Pad or truncate to MAX_SEQ_LENGTH
-#         input_tensor = tf.constant([current_input + [pad_token_id] * (MAX_SEQ_LENGTH - len(current_input))], dtype=tf.int32)
-
-#         # Get prediction
-#         prediction = reconstituted_model(input_tensor)
-#         next_token_id = int(tf.argmax(prediction[0], axis=-1).numpy())
-
-#         # Stop if pad token generated
-#         if next_token_id == pad_token_id:
-#             break
-
-#         generated_tokens.append(next_token_id)
-#         current_input.append(next_token_id)
-
-#         # Stop if we exceed max length
-#         if len(current_input) >= MAX_SEQ_LENGTH:
-#             break
-
-#     generated_texts.append((prompt_tokens, generated_tokens))
-
-# # Decode and print with proper formatting
-# for idx, (prompt_tokens, generated_tokens) in enumerate(generated_texts):
-#     # Decode prompt
-#     prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=False)
-
-#     # Extract original prompt content
-#     if '<prompt>' in prompt_text and '</prompt>' in prompt_text:
-#         original_prompt = prompt_text.split('<prompt>')[-1].split('</prompt>')[0]
-#     else:
-#         original_prompt = prompt_text[:50] + "..." if len(prompt_text) > 50 else prompt_text
-
-#     # Decode generated text
-#     generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False) if generated_tokens else ""
-
-#     print(f"\nGenerated text from sample {idx+1}:")
-#     print(f"<prompt>{original_prompt}</prompt>{generated_text}")
-
-
 
 ## Proper model wrapper and generation method (under development):
 
 print("###### Output of the model wrapper (under development) ########### ")
 
 # Register the config and model wrapper as serializable
 @tf.keras.utils.register_keras_serializable()
-class CerebrosAutoregressiveTextGeneratorConfig:
+class CerebrosNotGPTConfig:
     def __init__(self, max_sequence_length=1536, padding_token=None):
         self.max_sequence_length = max_sequence_length
         self.padding_token = padding_token
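The commented-out block removed in the hunk above implemented plain greedy decoding: re-run the model on the right-padded running sequence, take the argmax token, append it, and stop on the pad token or at the maximum sequence length. The following is a minimal standalone sketch of that approach, not the commit's own code; it assumes a model that returns next-token logits of shape (1, vocab_size), as the removed loop did, and names such as `reconstituted_model`, `tokenizer`, and `MAX_SEQ_LENGTH` come from the surrounding script.

import tensorflow as tf

def greedy_generate(model, prompt_tokens, pad_token_id, max_seq_length, max_new_tokens=100):
    """Greedy decoding: append the argmax token until a pad token,
    the max_new_tokens budget, or max_seq_length is reached."""
    current_input = list(prompt_tokens)
    generated_tokens = []
    for _ in range(max_new_tokens):
        # Right-pad the running sequence to the model's fixed input length.
        padding = [pad_token_id] * (max_seq_length - len(current_input))
        input_tensor = tf.constant([current_input + padding], dtype=tf.int32)
        # Assumed output shape: (1, vocab_size) next-token logits.
        logits = model(input_tensor)
        next_token_id = int(tf.argmax(logits[0], axis=-1).numpy())
        if next_token_id == pad_token_id:
            break  # the pad token doubles as the stop token here
        generated_tokens.append(next_token_id)
        current_input.append(next_token_id)
        if len(current_input) >= max_seq_length:
            break
    return generated_tokens

Usage in this script would look like greedy_generate(reconstituted_model, prompt_tokens, tokenizer.pad_token_id, MAX_SEQ_LENGTH).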
@@ -740,7 +678,7 @@ def from_config(cls, config):
         return cls(**config)
 
 @tf.keras.utils.register_keras_serializable()
-class CerebrosAutoregressiveTextGenerator(tf.keras.Model):
+class CerebrosNotGPT(tf.keras.Model):
     def __init__(self, config, **kwargs):
         super().__init__(**kwargs)
         self.config = config
@@ -846,17 +784,17 @@ def call(self, inputs):
 print("=" * 50)
 
 # Create config and generator
-config = CerebrosAutoregressiveTextGeneratorConfig(
+config = CerebrosNotGPTConfig(
     max_sequence_length=MAX_SEQ_LENGTH,
     padding_token=tokenizer.pad_token_id
 )
-generator = CerebrosAutoregressiveTextGenerator(config)
+generator = CerebrosNotGPT(config)
 
 print("########### BEFORE SEARIALIZING THE GENERATIVE MODEL")
 
 def complete_text(text):
     input_ids = tokenizer(
-        half_sample,
+        text,
         add_special_tokens=False
     )['input_ids']
 
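The hunks above show only the renamed class headers and constructors; the rest of CerebrosNotGPTConfig and CerebrosNotGPT lies outside this diff. As a reference for the pattern that @tf.keras.utils.register_keras_serializable() relies on, here is a minimal, hypothetical config class (its name and the round-trip calls below are illustrative assumptions, not the commit's code) showing how get_config/from_config let such an object survive Keras serialization.

import tensorflow as tf

@tf.keras.utils.register_keras_serializable()
class DemoGeneratorConfig:
    """Plain config object made Keras-serializable via get_config/from_config."""
    def __init__(self, max_sequence_length=1536, padding_token=None):
        self.max_sequence_length = max_sequence_length
        self.padding_token = padding_token

    def get_config(self):
        # Must return everything needed to rebuild the object.
        return {
            "max_sequence_length": self.max_sequence_length,
            "padding_token": self.padding_token,
        }

    @classmethod
    def from_config(cls, config):
        return cls(**config)

# Round trip through the Keras custom-object registry.
cfg = DemoGeneratorConfig(max_sequence_length=1536, padding_token=0)
blob = tf.keras.utils.serialize_keras_object(cfg)
restored = tf.keras.utils.deserialize_keras_object(blob)
assert restored.max_sequence_length == 1536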