Skip to content

Commit 1e09fc1

Browse files
committed
actually take advantadge of using namespace std;
1 parent 4d116ab commit 1e09fc1

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

examples/slm_basic_train_ex.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ int main(int argc, char** argv)
154154
// ----------------------------------------------------------------------------------------
155155
if (parser.option("train"))
156156
{
157-
std::cout << "=== TRAIN MODE ===\n";
157+
cout << "=== TRAIN MODE ===\n";
158158

159159
// 1) Prepare training data (simple approach)
160160
// We will store characters from shakespeare_text into a vector
@@ -163,7 +163,7 @@ int main(int argc, char** argv)
163163
auto full_tokens = char_based_tokenize(shakespeare_text);
164164
if (full_tokens.empty())
165165
{
166-
std::cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
166+
cerr << "ERROR: The Shakespeare text is empty. Please provide a valid training text.\n";
167167
return 0;
168168
}
169169

@@ -173,13 +173,13 @@ int main(int argc, char** argv)
173173
: 0;
174174

175175
// Display the size of the training text and the number of sequences
176-
std::cout << "Training text size: " << full_tokens.size() << " characters\n";
177-
std::cout << "Maximum number of sequences: " << max_sequences << "\n";
176+
cout << "Training text size: " << full_tokens.size() << " characters\n";
177+
cout << "Maximum number of sequences: " << max_sequences << "\n";
178178

179179
// Check if the text is too short
180180
if (max_sequences == 0)
181181
{
182-
std::cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
182+
cerr << "ERROR: The Shakespeare text is too short for training. It must contain at least "
183183
<< (max_seq_len + 1) << " characters.\n";
184184
return 0;
185185
}
@@ -203,7 +203,7 @@ int main(int argc, char** argv)
203203
// Shuffle samples and labels if the --shuffle option is enabled
204204
if (parser.option("shuffle"))
205205
{
206-
std::cout << "Shuffling training sequences and labels...\n";
206+
cout << "Shuffling training sequences and labels...\n";
207207
shuffle_samples_and_labels(samples, labels);
208208
}
209209

@@ -232,41 +232,41 @@ int main(int argc, char** argv)
232232
if (predicted[i] == labels[i])
233233
correct++;
234234
double accuracy = (double)correct / labels.size();
235-
std::cout << "Training accuracy (on this sample set): " << accuracy << "\n";
235+
cout << "Training accuracy (on this sample set): " << accuracy << "\n";
236236

237237
// 7) Save the model
238238
net.clean();
239239
serialize(model_file) << net;
240-
std::cout << "Model saved to " << model_file << "\n";
240+
cout << "Model saved to " << model_file << "\n";
241241
}
242242

243243
// ----------------------------------------------------------------------------------------
244244
// Generate mode
245245
// ----------------------------------------------------------------------------------------
246246
if (parser.option("generate"))
247247
{
248-
std::cout << "=== GENERATE MODE ===\n";
248+
cout << "=== GENERATE MODE ===\n";
249249
// 1) Load the trained model
250250
using net_infer = my_transformer_cfg::network_type<false>;
251251
net_infer net;
252252
if (file_exists(model_file))
253253
{
254254
deserialize(model_file) >> net;
255-
std::cout << "Loaded model from " << model_file << "\n";
255+
cout << "Loaded model from " << model_file << "\n";
256256
}
257257
else
258258
{
259-
std::cerr << "Error: model file not found. Please run --train first.\n";
259+
cerr << "Error: model file not found. Please run --train first.\n";
260260
return 0;
261261
}
262-
std::cout << my_transformer_cfg::model_info::describe() << std::endl;
263-
std::cout << "Model parameters: " << count_parameters(net) << std::endl << std::endl;
262+
cout << my_transformer_cfg::model_info::describe() << endl;
263+
cout << "Model parameters: " << count_parameters(net) << endl << endl;
264264

265265
// 2) Get the prompt from the included slm_data.h
266266
std::string prompt_text = shakespeare_prompt;
267267
if (prompt_text.empty())
268268
{
269-
std::cerr << "No prompt found in slm_data.h.\n";
269+
cerr << "No prompt found in slm_data.h.\n";
270270
return 0;
271271
}
272272
// If prompt is longer than max_seq_len, we keep only the first window
@@ -287,7 +287,7 @@ int main(int argc, char** argv)
287287
input_seq(i, 0) = PAD_TOKEN;
288288
}
289289

290-
std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
290+
cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
291291

292292
// 3) Generate new text
293293
// We'll predict one character at a time, then shift the window
@@ -296,22 +296,22 @@ int main(int argc, char** argv)
296296
const int next_char = net(input_seq); // single inference
297297

298298
// Print the generated character
299-
std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;
299+
cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << flush;
300300

301301
// Shift left by 1
302302
for (long i = 0; i < max_seq_len - 1; ++i)
303303
input_seq(i, 0) = input_seq(i + 1, 0);
304304
input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
305305
}
306306

307-
std::cout << "\n\n(end of generation)\n";
307+
cout << "\n\n(end of generation)\n";
308308
}
309309

310310
return 0;
311311
}
312-
catch (std::exception& e)
312+
catch (exception& e)
313313
{
314-
std::cerr << "Exception thrown: " << e.what() << std::endl;
314+
cerr << "Exception thrown: " << e.what() << endl;
315315
return 1;
316316
}
317317
}

0 commit comments

Comments
 (0)