5050// ----------------------------------------------------------------------------------------
5151
// We treat each character as a token ID in [0..255].
const int MAX_TOKEN_ID = 255;
const int PAD_TOKEN = 256; // an extra "pad" token if needed

// Convert raw text into a sequence of token IDs, one token per byte.
// Each byte maps to its unsigned value, clamped to MAX_TOKEN_ID.
// Returns an empty vector for an empty input string.
std::vector<int> char_based_tokenize(const std::string& text)
{
    std::vector<int> tokens;
    tokens.reserve(text.size());
    // Iterate as unsigned char: std::string yields plain char, which may be
    // signed, so bytes >= 0x80 would otherwise convert to NEGATIVE ints and
    // slip past the min() clamp, producing invalid (negative) token IDs.
    for (const unsigned char c : text)
    {
        tokens.push_back(std::min<int>(c, MAX_TOKEN_ID));
    }
    return tokens;
}
@@ -108,44 +108,18 @@ int main(int argc, char** argv)
108108
109109 if (parser.number_of_arguments () == 0 && !parser.option (" train" ) && !parser.option (" generate" ))
110110 {
111- std::cout << " Usage:\n "
112- << " --train : Train a small transformer model on the Shakespeare text\n "
113- << " --generate : Generate text from a trained model using a prompt\n "
114- << " --learning-rate <value> : Set the learning rate for training (default: 1e-4)\n "
115- << " --batch-size <value> : Set the mini-batch size for training (default: 64)\n "
116- << " --generation-length <value> : Set the length of generated text (default: 400)\n "
117- << " --alpha <value> : Set the initial learning rate for Adam optimizer (default: 0.004)\n "
118- << " --beta1 <value> : Set the decay rate for the first moment estimate (default: 0.9)\n "
119- << " --beta2 <value> : Set the decay rate for the second moment estimate (default: 0.999)\n "
120- << " --max-samples <value> : Set the maximum number of training samples (default: 50000)\n "
121- << " --shuffle : Shuffle training sequences and labels before training (default: false)\n " ;
111+ parser.print_options ();
122112 return 0 ;
123113 }
124114
125115 // Default values
126- double learning_rate = 1e-4 ;
127- long batch_size = 64 ;
128- int generation_length = 400 ;
129- double alpha = 0.004 ; // Initial learning rate for Adam
130- double beta1 = 0.9 ; // Decay rate for the first moment estimate
131- double beta2 = 0.999 ; // Decay rate for the second moment estimate
132- size_t max_samples = 50000 ; // Default maximum number of training samples
133-
134- // Override defaults if options are provided
135- if (parser.option (" learning-rate" ))
136- learning_rate = std::stod (parser.option (" learning-rate" ).argument ());
137- if (parser.option (" batch-size" ))
138- batch_size = std::stol (parser.option (" batch-size" ).argument ());
139- if (parser.option (" generation-length" ))
140- generation_length = std::stoi (parser.option (" generation-length" ).argument ());
141- if (parser.option (" alpha" ))
142- alpha = std::stod (parser.option (" alpha" ).argument ());
143- if (parser.option (" beta1" ))
144- beta1 = std::stod (parser.option (" beta1" ).argument ());
145- if (parser.option (" beta2" ))
146- beta2 = std::stod (parser.option (" beta2" ).argument ());
147- if (parser.option (" max-samples" ))
148- max_samples = std::stoul (parser.option (" max-samples" ).argument ());
116+ const double learning_rate = get_option (parser, " learning-rate" , 1e-4 );
117+ const long batch_size = get_option (parser, " batch-size" , 64 );
118+ const int generation_length = get_option (parser, " generation-length" , 400 );
119+ const double alpha = get_option (parser, " alpha" , 0.004 ); // Initial learning rate for Adam
120+ const double beta1 = get_option (parser, " beta1" , 0.9 ); // Decay rate for the first moment estimate
121+ const double beta2 = get_option (parser, " beta2" , 0.999 ); // Decay rate for the second moment estimate
122+ const size_t max_samples = get_option (parser, " max-samples" ,50000 ); // Default maximum number of training samples
149123
150124 // We define a minimal config for demonstration
151125 const long vocab_size = 257 ; // 0..255 for chars + 1 pad token
@@ -297,7 +271,7 @@ int main(int argc, char** argv)
297271 prompt_text.erase (prompt_text.begin () + max_seq_len, prompt_text.end ());
298272
299273 // Convert prompt to a token sequence
300- auto prompt_tokens = char_based_tokenize (prompt_text);
274+ const auto prompt_tokens = char_based_tokenize (prompt_text);
301275
302276 // Put into a dlib matrix
303277 dlib::matrix<int , 0 , 1 > input_seq (max_seq_len, 1 );
@@ -310,37 +284,21 @@ int main(int argc, char** argv)
310284 input_seq (i, 0 ) = PAD_TOKEN;
311285 }
312286
313- std::cout << " Initial prompt:\n " << prompt_text << " (...)\n\n Generated text:\n " << prompt_text;
287+ std::cout << " \n Initial prompt:\n " << prompt_text << " (...)\n \n\n Generated text:\n " << prompt_text;
314288
315289 // 3) Generate new text
316290 // We'll predict one character at a time, then shift the window
317- // until the total length is at least generation_length and we encounter two newlines.
318- std::string generated_text = prompt_text;
319- bool stop_generation = false ;
320-
321- while (generated_text.size () < (size_t )generation_length || !stop_generation)
291+ for (int i = 0 ; i < generation_length; ++i)
322292 {
323- unsigned long next_char = net (input_seq); // single inference
324-
325- // Append the generated character to the text
326- generated_text += (char )(std::min<unsigned long >(next_char, MAX_TOKEN_ID));
293+ const int next_char = net (input_seq); // single inference
327294
328295 // Print the generated character
329- std::cout << ( char ) (std::min< unsigned long > (next_char, MAX_TOKEN_ID));
296+ std::cout << static_cast < char > (std::min (next_char, MAX_TOKEN_ID)) << std::flush ;
330297
331298 // Shift left by 1
332299 for (long i = 0 ; i < max_seq_len - 1 ; ++i)
333300 input_seq (i, 0 ) = input_seq (i + 1 , 0 );
334- input_seq (max_seq_len - 1 , 0 ) = (int )std::min<unsigned long >(next_char, MAX_TOKEN_ID);
335-
336- // Check if the last two characters are newlines
337- if (generated_text.size () >= 2 &&
338- generated_text[generated_text.size () - 1 ] == ' \n ' &&
339- generated_text[generated_text.size () - 2 ] == ' \n ' )
340- {
341- // Stop generation if the minimum length is reached
342- if (generated_text.size () >= (size_t )generation_length) stop_generation = true ;
343- }
301+ input_seq (max_seq_len - 1 , 0 ) = std::min (next_char, MAX_TOKEN_ID);
344302 }
345303
346304 std::cout << " \n\n (end of generation)\n " ;
@@ -391,4 +349,4 @@ int main(int argc, char** argv)
391349 * > QUEEN ELIZABETH:
392350 * > I go. Write to me very shortly.
393351 * > And you shall understand from me her mind.
394- */
352+ */
0 commit comments