
Commit 19121aa

simplify code a little
1 parent 3ab9801 commit 19121aa

File tree

1 file changed (+20, -62 lines)

examples/slm_basic_train_ex.cpp

Lines changed: 20 additions & 62 deletions
@@ -50,17 +50,17 @@
 // ----------------------------------------------------------------------------------------
 
 // We treat each character as a token ID in [0..255].
-static const int MAX_TOKEN_ID = 255;
-static const int PAD_TOKEN = 256; // an extra "pad" token if needed
+const int MAX_TOKEN_ID = 255;
+const int PAD_TOKEN = 256; // an extra "pad" token if needed
 
 // For simplicity, we assume each line from shakespeare_text is appended, ignoring them.
-static std::vector<int> char_based_tokenize(const std::string& text)
+std::vector<int> char_based_tokenize(const std::string& text)
 {
     std::vector<int> tokens;
     tokens.reserve(text.size());
-    for (unsigned char c : text)
+    for (const int c : text)
     {
-        tokens.push_back(std::min<int>(c, MAX_TOKEN_ID));
+        tokens.push_back(std::min(c, MAX_TOKEN_ID));
     }
     return tokens;
 }
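The tokenizer touched by this hunk is small enough to exercise on its own. Below is a self-contained sketch of the same character-level scheme (illustrative only, not part of the commit): every input byte maps to a token ID in [0..255], with 256 reserved as a pad token. The sketch iterates over unsigned char so that bytes above 127 stay non-negative.

// Standalone sketch of the character-level tokenization used in this example.
// Not part of the commit; for illustration only.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

const int MAX_TOKEN_ID = 255;
const int PAD_TOKEN = 256; // reserved for padding, never produced by the tokenizer

std::vector<int> char_based_tokenize(const std::string& text)
{
    std::vector<int> tokens;
    tokens.reserve(text.size());
    for (const unsigned char c : text) // unsigned keeps bytes >= 128 non-negative
        tokens.push_back(std::min<int>(c, MAX_TOKEN_ID));
    return tokens;
}

int main()
{
    for (const int t : char_based_tokenize("To be"))
        std::cout << t << ' '; // prints: 84 111 32 98 101
    std::cout << '\n';
}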
@@ -108,44 +108,18 @@ int main(int argc, char** argv)
 
     if (parser.number_of_arguments() == 0 && !parser.option("train") && !parser.option("generate"))
     {
-        std::cout << "Usage:\n"
-            << " --train : Train a small transformer model on the Shakespeare text\n"
-            << " --generate : Generate text from a trained model using a prompt\n"
-            << " --learning-rate <value> : Set the learning rate for training (default: 1e-4)\n"
-            << " --batch-size <value> : Set the mini-batch size for training (default: 64)\n"
-            << " --generation-length <value> : Set the length of generated text (default: 400)\n"
-            << " --alpha <value> : Set the initial learning rate for Adam optimizer (default: 0.004)\n"
-            << " --beta1 <value> : Set the decay rate for the first moment estimate (default: 0.9)\n"
-            << " --beta2 <value> : Set the decay rate for the second moment estimate (default: 0.999)\n"
-            << " --max-samples <value> : Set the maximum number of training samples (default: 50000)\n"
-            << " --shuffle : Shuffle training sequences and labels before training (default: false)\n";
+        parser.print_options();
         return 0;
     }
 
     // Default values
-    double learning_rate = 1e-4;
-    long batch_size = 64;
-    int generation_length = 400;
-    double alpha = 0.004; // Initial learning rate for Adam
-    double beta1 = 0.9; // Decay rate for the first moment estimate
-    double beta2 = 0.999; // Decay rate for the second moment estimate
-    size_t max_samples = 50000; // Default maximum number of training samples
-
-    // Override defaults if options are provided
-    if (parser.option("learning-rate"))
-        learning_rate = std::stod(parser.option("learning-rate").argument());
-    if (parser.option("batch-size"))
-        batch_size = std::stol(parser.option("batch-size").argument());
-    if (parser.option("generation-length"))
-        generation_length = std::stoi(parser.option("generation-length").argument());
-    if (parser.option("alpha"))
-        alpha = std::stod(parser.option("alpha").argument());
-    if (parser.option("beta1"))
-        beta1 = std::stod(parser.option("beta1").argument());
-    if (parser.option("beta2"))
-        beta2 = std::stod(parser.option("beta2").argument());
-    if (parser.option("max-samples"))
-        max_samples = std::stoul(parser.option("max-samples").argument());
+    const double learning_rate = get_option(parser, "learning-rate", 1e-4);
+    const long batch_size = get_option(parser, "batch-size", 64);
+    const int generation_length = get_option(parser, "generation-length", 400);
+    const double alpha = get_option(parser, "alpha", 0.004); // Initial learning rate for Adam
+    const double beta1 = get_option(parser, "beta1", 0.9); // Decay rate for the first moment estimate
+    const double beta2 = get_option(parser, "beta2", 0.999); // Decay rate for the second moment estimate
+    const size_t max_samples = get_option(parser, "max-samples", 50000); // Default maximum number of training samples
 
     // We define a minimal config for demonstration
     const long vocab_size = 257; // 0..255 for chars + 1 pad token
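The replacement lines lean on dlib's get_option helper (pulled in by dlib/cmd_line_parser.h), which returns the option's argument converted to the type of the supplied default, or the default itself when the option is absent. That is what lets each "declare, then conditionally override" pair collapse into a single const initialization. A minimal usage sketch follows; the option name is made up for illustration.

// Minimal usage sketch of dlib::get_option with a command line parser.
#include <dlib/cmd_line_parser.h>
#include <iostream>

int main(int argc, char** argv)
{
    dlib::command_line_parser parser;
    parser.add_option("batch-size", "mini-batch size", 1); // option takes 1 argument
    parser.parse(argc, argv);

    // 64 is returned when --batch-size was not given on the command line.
    const long batch_size = dlib::get_option(parser, "batch-size", 64);
    std::cout << "batch size: " << batch_size << '\n';
}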
@@ -297,7 +271,7 @@ int main(int argc, char** argv)
         prompt_text.erase(prompt_text.begin() + max_seq_len, prompt_text.end());
 
     // Convert prompt to a token sequence
-    auto prompt_tokens = char_based_tokenize(prompt_text);
+    const auto prompt_tokens = char_based_tokenize(prompt_text);
 
     // Put into a dlib matrix
     dlib::matrix<int, 0, 1> input_seq(max_seq_len, 1);
@@ -310,37 +284,21 @@ int main(int argc, char** argv)
         input_seq(i, 0) = PAD_TOKEN;
     }
 
-    std::cout << "Initial prompt:\n" << prompt_text << " (...)\n\nGenerated text:\n" << prompt_text;
+    std::cout << "\nInitial prompt:\n" << prompt_text << " (...)\n\n\nGenerated text:\n" << prompt_text;
 
     // 3) Generate new text
     // We'll predict one character at a time, then shift the window
-    // until the total length is at least generation_length and we encounter two newlines.
-    std::string generated_text = prompt_text;
-    bool stop_generation = false;
-
-    while (generated_text.size() < (size_t)generation_length || !stop_generation)
+    for (int i = 0; i < generation_length; ++i)
     {
-        unsigned long next_char = net(input_seq); // single inference
-
-        // Append the generated character to the text
-        generated_text += (char)(std::min<unsigned long>(next_char, MAX_TOKEN_ID));
+        const int next_char = net(input_seq); // single inference
 
         // Print the generated character
-        std::cout << (char)(std::min<unsigned long>(next_char, MAX_TOKEN_ID));
+        std::cout << static_cast<char>(std::min(next_char, MAX_TOKEN_ID)) << std::flush;
 
         // Shift left by 1
         for (long i = 0; i < max_seq_len - 1; ++i)
             input_seq(i, 0) = input_seq(i + 1, 0);
-        input_seq(max_seq_len - 1, 0) = (int)std::min<unsigned long>(next_char, MAX_TOKEN_ID);
-
-        // Check if the last two characters are newlines
-        if (generated_text.size() >= 2 &&
-            generated_text[generated_text.size() - 1] == '\n' &&
-            generated_text[generated_text.size() - 2] == '\n')
-        {
-            // Stop generation if the minimum length is reached
-            if (generated_text.size() >= (size_t)generation_length) stop_generation = true;
-        }
+        input_seq(max_seq_len - 1, 0) = std::min(next_char, MAX_TOKEN_ID);
     }
 
     std::cout << "\n\n(end of generation)\n";
@@ -391,4 +349,4 @@ int main(int argc, char** argv)
 * > QUEEN ELIZABETH:
 * > I go. Write to me very shortly.
 * > And you shall understand from me her mind.
-*/
+*/
