 
 // ----------------------------------------------------------------------------------------
 
+using namespace std;
+using namespace dlib;
+
 // We treat each character as a token ID in [0..255].
 const int MAX_TOKEN_ID = 255;
 const int PAD_TOKEN = 256; // an extra "pad" token if needed
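The char_based_tokenize helper referenced just below simply maps every byte of the text to an ID in [0..255]. A minimal standalone sketch of that mapping (the helper's body is not part of this hunk, so the name and exact form here are illustrative):

#include <string>
#include <vector>

// Sketch: turn each byte of `text` into a token ID in [0..255], consistent with MAX_TOKEN_ID.
std::vector<int> tokenize_bytes(const std::string& text)
{
    std::vector<int> tokens;
    tokens.reserve(text.size());
    for (const unsigned char c : text)
        tokens.push_back(static_cast<int>(c)); // unsigned char keeps every ID non-negative
    return tokens;
}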
@@ -66,13 +69,13 @@ std::vector<int> char_based_tokenize(const std::string& text)
 }
 
 // Function to shuffle samples and labels in sync
-void shuffle_samples_and_labels(std::vector<dlib::matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
+void shuffle_samples_and_labels(std::vector<matrix<int, 0, 1>>& samples, std::vector<unsigned long>& labels) {
     std::vector<size_t> indices(samples.size());
     std::iota(indices.begin(), indices.end(), 0); // Fill with 0, 1, 2, ..., N-1
     std::shuffle(indices.begin(), indices.end(), std::default_random_engine{});
 
     // Create temporary vectors to hold shuffled data
-    std::vector<dlib::matrix<int, 0, 1>> shuffled_samples(samples.size());
+    std::vector<matrix<int, 0, 1>> shuffled_samples(samples.size());
     std::vector<unsigned long> shuffled_labels(labels.size());
 
     // Apply the shuffle
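The hunk ends just before the permutation is applied. For completeness, here is a standard-library-only sketch of the same synchronized-shuffle pattern (plain ints stand in for the dlib matrices). Note that a default-constructed std::default_random_engine, as used above, produces the same permutation on every run; seeding it, as in this sketch, is one way to vary it:

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Shuffle two parallel vectors with a single shared index permutation.
void shuffle_in_sync(std::vector<int>& samples, std::vector<unsigned long>& labels)
{
    std::vector<size_t> indices(samples.size());
    std::iota(indices.begin(), indices.end(), 0);
    std::shuffle(indices.begin(), indices.end(), std::default_random_engine{std::random_device{}()});

    std::vector<int> shuffled_samples(samples.size());
    std::vector<unsigned long> shuffled_labels(labels.size());
    for (size_t i = 0; i < indices.size(); ++i)
    {
        shuffled_samples[i] = samples[indices[i]]; // the same permutation drives both containers
        shuffled_labels[i]  = labels[indices[i]];
    }
    samples.swap(shuffled_samples);
    labels.swap(shuffled_labels);
}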
@@ -93,15 +96,15 @@ int main(int argc, char** argv)
 {
     try
     {
-        dlib::command_line_parser parser;
+        command_line_parser parser;
         parser.add_option("train", "Train a small transformer on the built-in Shakespeare text");
         parser.add_option("generate", "Generate text from a previously trained model (needs shakespeare_prompt)");
         parser.add_option("learning-rate", "Set the learning rate for training (default: 1e-4)", 1);
         parser.add_option("batch-size", "Set the mini-batch size for training (default: 64)", 1);
         parser.add_option("generation-length", "Set the length of generated text (default: 400)", 1);
-        parser.add_option("alpha", "Set the initial learning rate for Adam optimizer (default: 0.004)", 1);
-        parser.add_option("beta1", "Set the decay rate for the first moment estimate (default: 0.9)", 1);
-        parser.add_option("beta2", "Set the decay rate for the second moment estimate (default: 0.999)", 1);
+        parser.add_option("alpha", "Set the weight decay for Adam optimizer (default: 0.004)", 1);
+        parser.add_option("beta1", "Set the first moment coefficient (default: 0.9)", 1);
+        parser.add_option("beta2", "Set the second moment coefficient (default: 0.999)", 1);
         parser.add_option("max-samples", "Set the maximum number of training samples (default: 50000)", 1);
         parser.add_option("shuffle", "Shuffle training sequences and labels before training (default: false)");
         parser.parse(argc, argv);
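For readers new to dlib's option handling, the pattern used above reduces to the following self-contained sketch; the option names are invented for the illustration, and dlib::get_option returns the supplied default whenever the option was not given on the command line:

#include <dlib/cmd_line_parser.h>
#include <iostream>

int main(int argc, char** argv)
{
    dlib::command_line_parser parser;
    parser.add_option("train", "Run in training mode");                  // a flag: no argument
    parser.add_option("batch-size", "Mini-batch size (default: 64)", 1); // takes one argument
    parser.parse(argc, argv);

    const long batch_size = dlib::get_option(parser, "batch-size", 64);  // falls back to 64
    if (parser.option("train"))
        std::cout << "training with batch size " << batch_size << "\n";
    return 0;
}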
@@ -122,7 +125,7 @@ int main(int argc, char** argv)
         const size_t max_samples = get_option(parser, "max-samples", 50000); // Default maximum number of training samples
 
         // We define a minimal config for demonstration
-        const long vocab_size = 257; // 0..255 for chars + 1 pad token
+        const long vocab_size = MAX_TOKEN_ID + 1 + 1; // 256 for chars + 1 pad token
         const long num_layers = 3;
         const long num_heads = 4;
         const long embedding_dim = 64;
@@ -136,8 +139,8 @@ int main(int argc, char** argv)
             embedding_dim,
             max_seq_len,
             use_squeezing,
-            dlib::gelu,
-            dlib::dropout_10
+            gelu,
+            dropout_10
         >;
 
         // For GPU usage (if any), set gpus = {0} for a single GPU, etc.
@@ -151,7 +154,7 @@ int main(int argc, char** argv)
         // ----------------------------------------------------------------------------------------
         if (parser.option("train"))
         {
-            std::cout << "=== TRAIN MODE ===\n";
+            cout << "=== TRAIN MODE ===\n";
 
             // 1) Prepare training data (simple approach)
             //    We will store characters from shakespeare_text into a vector
@@ -160,7 +163,7 @@ int main(int argc, char** argv)
             auto full_tokens = char_based_tokenize(shakespeare_text);
             if (full_tokens.empty())
             {
-                std::cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
+                cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
                 return 0;
             }
 
@@ -170,18 +173,18 @@ int main(int argc, char** argv)
                 : 0;
 
             // Display the size of the training text and the number of sequences
-            std::cout << "Training text size: " << full_tokens.size() << " characters\n";
-            std::cout << "Maximum number of sequences: " << max_sequences << "\n";
+            cout << "Training text size: " << full_tokens.size() << " characters\n";
+            cout << "Maximum number of sequences: " << max_sequences << "\n";
 
             // Check if the text is too short
             if (max_sequences == 0)
             {
-                std::cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
+                cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
                     << (max_seq_len + 1) << " characters.\n";
                 return 0;
             }
 
-            std::vector<dlib::matrix<int, 0, 1>> samples;
+            std::vector<matrix<int, 0, 1>> samples;
             std::vector<unsigned long> labels;
 
             // Let's create a training set of about (N) samples from the text
@@ -190,7 +193,7 @@ int main(int argc, char** argv)
             const size_t N = (max_sequences < max_samples) ? max_sequences : max_samples;
             for (size_t start = 0; start < N; ++start)
             {
-                dlib::matrix<int, 0, 1> seq(max_seq_len, 1);
+                matrix<int, 0, 1> seq(max_seq_len, 1);
                 for (long t = 0; t < max_seq_len; ++t)
                     seq(t, 0) = full_tokens[start + t];
                 samples.push_back(seq);
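Each window of max_seq_len consecutive characters becomes one sample; the max_sequences formula above and the one-character-at-a-time generation below imply that its label is the character immediately following the window. A tiny standalone illustration of that pairing (text and window length chosen arbitrarily):

#include <iostream>
#include <string>

int main()
{
    const std::string text = "To be, or not to be";
    const std::size_t max_seq_len = 4;
    for (std::size_t start = 0; start + max_seq_len < text.size(); ++start)
    {
        const std::string window = text.substr(start, max_seq_len);
        const char label = text[start + max_seq_len]; // the character the model must predict
        std::cout << '"' << window << '"' << " -> '" << label << "'\n";
    }
    return 0;
}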
@@ -200,18 +203,18 @@ int main(int argc, char** argv)
             // Shuffle samples and labels if the --shuffle option is enabled
             if (parser.option("shuffle"))
             {
-                std::cout << "Shuffling training sequences and labels...\n";
+                cout << "Shuffling training sequences and labels...\n";
                 shuffle_samples_and_labels(samples, labels);
             }
 
             // 3) Construct the network in training mode
             using net_type = my_transformer_cfg::network_type<true>;
             net_type net;
-            if (dlib::file_exists(model_file))
-                dlib::deserialize(model_file) >> net;
+            if (file_exists(model_file))
+                deserialize(model_file) >> net;
 
             // 4) Create dnn_trainer
-            dlib::dnn_trainer<net_type, dlib::adam> trainer(net, dlib::adam(alpha, beta1, beta2), gpus);
+            dnn_trainer<net_type, adam> trainer(net, adam(alpha, beta1, beta2), gpus);
             trainer.set_learning_rate(learning_rate);
             trainer.set_min_learning_rate(1e-6);
             trainer.set_mini_batch_size(batch_size);
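As the corrected option descriptions point out, dlib's adam solver is parameterized as adam(weight_decay, momentum1, momentum2); the step size itself is set on the trainer. A minimal compilable sketch of that wiring, using a deliberately tiny throwaway network instead of the transformer defined above:

#include <dlib/dnn.h>

int main()
{
    // A toy network, only there to give the trainer something to drive.
    using toy_net = dlib::loss_multiclass_log<dlib::fc<2, dlib::input<dlib::matrix<float>>>>;
    toy_net net;

    // adam(weight_decay, momentum1, momentum2) -- none of these is the learning rate.
    dlib::dnn_trainer<toy_net, dlib::adam> trainer(net, dlib::adam(0.004, 0.9, 0.999));
    trainer.set_learning_rate(1e-4);   // the actual step size lives here
    trainer.set_min_learning_rate(1e-6);
    trainer.set_mini_batch_size(64);
    trainer.be_verbose();
    return 0;
}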
@@ -229,41 +232,41 @@ int main(int argc, char** argv)
                 if (predicted[i] == labels[i])
                     correct++;
             double accuracy = (double)correct / labels.size();
-            std::cout << "Training accuracy (on this sample set): " << accuracy << "\n";
+            cout << "Training accuracy (on this sample set): " << accuracy << "\n";
 
             // 7) Save the model
             net.clean();
-            dlib::serialize(model_file) << net;
-            std::cout << "Model saved to " << model_file << "\n";
+            serialize(model_file) << net;
+            cout << "Model saved to " << model_file << "\n";
         }
 
         // ----------------------------------------------------------------------------------------
         // Generate mode
         // ----------------------------------------------------------------------------------------
         if (parser.option("generate"))
         {
-            std::cout << "=== GENERATE MODE ===\n";
+            cout << "=== GENERATE MODE ===\n";
             // 1) Load the trained model
             using net_infer = my_transformer_cfg::network_type<false>;
             net_infer net;
-            if (dlib::file_exists(model_file))
+            if (file_exists(model_file))
             {
-                dlib::deserialize(model_file) >> net;
-                std::cout << "Loaded model from " << model_file << "\n";
+                deserialize(model_file) >> net;
+                cout << "Loaded model from " << model_file << "\n";
             }
             else
             {
-                std::cerr << "Error: model file not found. Please run --train first.\n";
+                cerr << "Error: model file not found. Please run --train first.\n";
                 return 0;
             }
-            std::cout << my_transformer_cfg::model_info::describe() << std::endl;
-            std::cout << "Model parameters: " << count_parameters(net) << std::endl << std::endl;
+            cout << my_transformer_cfg::model_info::describe() << endl;
+            cout << "Model parameters: " << count_parameters(net) << endl << endl;
 
             // 2) Get the prompt from the included slm_data.h
             std::string prompt_text = shakespeare_prompt;
             if (prompt_text.empty())
             {
-                std::cerr << "No prompt found in slm_data.h.\n";
+                cerr << "No prompt found in slm_data.h.\n";
                 return 0;
             }
             // If prompt is longer than max_seq_len, we keep only the first window
@@ -274,7 +277,7 @@ int main(int argc, char** argv)
             const auto prompt_tokens = char_based_tokenize(prompt_text);
 
             // Put into a dlib matrix
-            dlib::matrix<int, 0, 1> input_seq(max_seq_len, 1);
+            matrix<int, 0, 1> input_seq(max_seq_len, 1);
             // Fill with pad if prompt is shorter than max_seq_len
             for (long i = 0; i < max_seq_len; ++i)
             {
@@ -284,7 +287,7 @@ int main(int argc, char** argv)
                     input_seq(i, 0) = PAD_TOKEN;
             }
 
-            std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
+            cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
 
             // 3) Generate new text
             //    We'll predict one character at a time, then shift the window
@@ -293,22 +296,22 @@ int main(int argc, char** argv)
                 const int next_char = net(input_seq); // single inference
 
                 // Print the generated character
-                std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;
+                cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << flush;
 
                 // Shift left by 1
                 for (long i = 0; i < max_seq_len - 1; ++i)
                     input_seq(i, 0) = input_seq(i + 1, 0);
                 input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
             }
 
-            std::cout << "\n\n(end of generation)\n";
+            cout << "\n\n(end of generation)\n";
         }
 
         return 0;
     }
-    catch (std::exception& e)
+    catch (exception& e)
     {
-        std::cerr << "Exception thrown: " << e.what() << std::endl;
+        cerr << "Exception thrown: " << e.what() << endl;
         return 1;
     }
 }
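The generation loop above is a plain sliding window: run one inference, print the predicted character, shift the window left by one, and append the prediction. The same mechanics in isolation, with the network replaced by a stub predictor so the snippet runs on its own:

#include <iostream>
#include <vector>

// Stand-in for net(input_seq): any function returning the next token ID would do.
int fake_next_token(const std::vector<int>& window) { return (window.back() + 1) % 256; }

int main()
{
    std::vector<int> window = {'a', 'b', 'c', 'd'}; // plays the role of input_seq
    for (int step = 0; step < 8; ++step)
    {
        const int next = fake_next_token(window);
        std::cout << static_cast<char>(next);
        // Shift left by one and append the freshly generated token.
        for (std::size_t i = 0; i + 1 < window.size(); ++i)
            window[i] = window[i + 1];
        window.back() = next;
    }
    std::cout << "\n";
    return 0;
}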