@@ -786,7 +786,7 @@ def generate(self, token_ids, do_sample=False, max_new_tokens=None):
         current_tokens = token_ids.copy()

         # Autoregressive generation loop
-        temp_gen_count = 0  # <--------<< Debug code to remove later
+        # temp_gen_count = 0  # <--------<< Debug code to remove later
         for _ in range(max_new_tokens):
             # Pad or truncate to max_sequence_length (CORRECTED PADDING LOGIC)
             if len(current_tokens) > self.max_sequence_length:
@@ -809,11 +809,11 @@ def generate(self, token_ids, do_sample=False, max_new_tokens=None):
             # Greedy sampling (argmax)
             next_token_id = int(tf.argmax(logits[0], axis=-1).numpy())
             # Debug code to remove later
-            print(f"Generating {temp_gen_count}")
-            print(f"... next_token_id: {next_token_id}")
-            next_word = tokenizer.decode(next_token_id)
-            print(f"Next decoded word: {next_word}")
-            temp_gen_count += 1
+            # print(f"Generating {temp_gen_count}")
+            # print(f"... next_token_id: {next_token_id}")
+            # next_word = tokenizer.decode(next_token_id)
+            # print(f"Next decoded word: {next_word}")
+            # temp_gen_count += 1

             # Check for termination condition
             if next_token_id == self.padding_token:
@@ -854,16 +854,38 @@ def call(self, inputs):

 print("########### BEFORE SERIALIZING THE GENERATIVE MODEL")

+def complete_text(text):
+    input_ids = tokenizer(
+        text,
+        add_special_tokens=False
+    )['input_ids']
+
+    generated_tokens = generator.generate(
+        token_ids=input_ids,  # Just the actual tokens, no padding
+        do_sample=False,
+        max_new_tokens=40
+    )
+    generated_text = \
+        tokenizer.decode(generated_tokens).replace(text, "")
+    return generated_text
+
+test_text = "I saw the sun and it was"
+response = complete_text(test_text)
+
+print(f"I ask the generator: {test_text} ... It responds:")
+print(response)
+
 counter = 0
 for sample in non_instruct_samples:
-    half_sample_len = int(np.ceil(len(sample) / 2))
-    half_sample = sample[:half_sample_len]
+

     # Tokenize the text without padding first to get actual tokens
-    half_sample_tokenized = tokenizer(
-        half_sample,
+    sample_tokenized = tokenizer(
+        sample,
         add_special_tokens=False
     )['input_ids']
+    half_index = int(np.ceil(len(sample_tokenized) * 0.5))
+    half_sample_tokenized = sample_tokenized[:half_index]

     # Convert to Python list of integers
     if hasattr(half_sample_tokenized, 'numpy'):
@@ -883,62 +905,10 @@ def call(self, inputs):

     # Decode the result
     full_generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)
-    print(f"PROMPT number {counter}: {half_sample}; RESPONSE: {full_generated_text}")
+    print(f"PROMPT number {counter}: {tokenizer.decode(half_sample_tokenized)}; RESPONSE: {full_generated_text.replace(tokenizer.decode(half_sample_tokenized), '')}")
     counter += 1


-
-
-
-# # Process ALL original samples from data - REAL WORLD USAGE
-# generated_texts = []
-# for i, original_text in enumerate(data[:5]): # Process first 5 samples
-# print(f"\nProcessing sample {i+1}...")
-
-# # Extract prompt part (everything up to and including </prompt>)
-# if '</prompt>' in original_text:
-# prompt_part = original_text.split('</prompt>')[0] + '</prompt>'
-# else:
-# prompt_part = original_text
-
-# # Tokenize the prompt part
-# tokenized = tokenizer(
-# prompt_part,
-# add_special_tokens=False, # We handle special tokens manually
-# return_tensors=None # Return lists, not tensors
-# )
-# prompt_tokens = tokenized['input_ids']
-
-# print(f"Original prompt: {prompt_part[:100]}...")
-# print(f"Tokenized prompt length: {len(prompt_tokens)} tokens")
-
-# # Generate tokens using the wrapper class - REAL WORLD USAGE
-# generated_tokens = generator.generate(
-# token_ids=prompt_tokens,
-# do_sample=False,
-# max_new_tokens=100
-# )
-
-# # Decode the full generated text
-# full_generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)
-
-# # Extract just the newly generated part (after the prompt)
-# generated_part = full_generated_text[len(prompt_part):]
-
-# generated_texts.append((prompt_part, generated_part))
-
-# print(f"Generated response: {generated_part}...")
-
-# # Display results with proper formatting
-# print("\n" + "="*50)
-# print("FINAL GENERATED RESULTS")
-# print("="*50)
-
-# for idx, (original_prompt, generated_response) in enumerate(generated_texts):
-# print(f"\nSample {idx+1}:")
-# print(f"Prompt:{original_prompt}")
-# print(f"Response: {generated_response}")
-
 # Save the model
 model_save_path = f"{TIME}_cerebros-autoregressive-text-generator.keras"
 generator.save(model_save_path)
@@ -988,96 +958,9 @@ def call(self, inputs):
     counter += 1


-# # Test with all original data samples - REAL WORLD DEMO (reconstituted)
-# print("\n" + "="*50)
-# print("GENERATED TEXT SAMPLES FROM ALL DATA - REAL WORLD USAGE (reconstituted)")
-# print("="*50)
-
-# generated_texts_all = []
-# for i, text in enumerate(data):
-# # Extract prompt part (everything up to and including </prompt>)
-# if '</prompt>' in text:
-# prompt_text = text.split('</prompt>')[0] + '</prompt>'
-# else:
-# prompt_text = text
-
-# # Tokenize the prompt part for model input
-# tokenized = tokenizer(
-# prompt_text,
-# max_length=MAX_SEQ_LENGTH,
-# padding='max_length',
-# truncation=True,
-# add_special_tokens=False
-# )
-# token_ids = tokenized['input_ids']
-
-# # Generate using the reconstituted model
-# generated_token_ids = reconstituted_generator.generate(
-# token_ids=token_ids,
-# do_sample=False,
-# max_new_tokens=100
-# )
-
-# # Decode generated text
-# generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=False)
-# generated_texts_all.append(generated_text)
-
-
-# print(f"\nSample {i+1}:")
-# print(f"Prompt: {prompt_text}")
-# print(f"Generated: {generated_text}")
-# # [len(prompt_text):][:200]}...")
-
 print("\n All samples processed with reconstituted model!")


-# # Test with all original data samples
-# print("\n" + "="*50)
-# print("GENERATED TEXT SAMPLES FROM ALL DATA")
-# print("="*50)
-
-# generated_texts_all = []
-# for i, text in enumerate(data[:3]): # Process first 3 for demo
-# # Split such that everything before </prompt> or the entire text if </prompt> is not present
-# if '</prompt>' in text:
-# prompt_text = text.split('</prompt>')[0] + '</prompt>'
-# else:
-# prompt_text = text
-
-# # Tokenize with proper padding
-# tokenized = tokenizer(
-# prompt_text,
-# max_length=MAX_SEQ_LENGTH,
-# padding='max_length',
-# truncation=True,
-# add_special_tokens=False
-# )
-# token_ids = tokenized['input_ids']
-
-# # Generate using the reconstituted model
-# generated_token_ids = reconstituted_generator.generate(
-# token_ids=token_ids,
-# do_sample=False,
-# max_new_tokens=100
-# )
-
-# # Decode generated text
-# generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=False)
-# generated_texts_all.append(generated_text)
-
-# # Extract and print prompt for display
-# if '<prompt>' in text and '</prompt>' in text:
-# display_prompt = text.split('<prompt>')[1].split('</prompt>')[0]
-# else:
-# display_prompt = text[:100] + "..." if len(text) > 100 else text
-
-# print(f"\nSample {i+1}:")
-# print(f"Prompt: {text}")
-# print(f"Generated: {generated_text}") # [len(prompt_text):][:200]}...")
-
-print("\n All samples processed!")
-
-

 ## Model validation
 print("Validation")
@@ -1086,6 +969,7 @@ def call(self, inputs):
     metrics=['accuracy']
 )

+
 results = reconstituted_model.evaluate(x_test_packaged, y_test_packaged)
 print("Test loss:", results[0])
 print("Test accuracy:", results[-1])