
Commit 952513c

Update
1 parent c6f6979


examples/slm_chatbot_ex.cpp

Lines changed: 18 additions & 23 deletions
@@ -257,16 +257,27 @@ int main(int argc, char** argv)
         cout << "=== FINE-TUNING MODE ===\n";
         cout << "Objective: specialize model for conversational Q&A with proper formatting\n\n";
 
-        // Load tokenizer & last checkpoint from the global training
+        // Setup trainer for fine-tuning
         std::string finetuned_model = model_file.substr(0, model_file.find_last_of('.'))
             + "_finetuned.dat";
+        train_net net;
+        dnn_trainer<train_net, adam> trainer(net, adam(weight_decay, beta1, beta2), gpus);
+        trainer.set_learning_rate(learning_rate);
+        trainer.set_min_learning_rate(1e-7);
+        trainer.set_mini_batch_size(batch_size);
+        trainer.set_max_num_epochs(max_epochs);
+        trainer.set_iterations_without_progress_threshold(patience);
+        trainer.set_synchronization_file("chkpt-" + finetuned_model, std::chrono::minutes(5));
+        trainer.be_quiet();
 
+        // Load tokenizer & model
         bpe_tokenizer tokenizer;
-        if (file_exists(tokenizer_file)) {
-            cout << "Loading pre-trained tokenizer from: " << tokenizer_file << endl;
-            deserialize(tokenizer_file) >> tokenizer;
-            cout << "Tokenizer loaded successfully with vocabulary size: " << tokenizer.get_vocab_size() << endl;
-        }
+        if (file_exists(model_file) &&
+            !file_exists("chkpt-" + finetuned_model)) deserialize(model_file) >> net >> tokenizer;
+        else if (file_exists(finetuned_model) &&
+            !file_exists("chkpt-" + finetuned_model)) deserialize(finetuned_model) >> net >> tokenizer;
+        else if (file_exists(tokenizer_file)) {
+            deserialize(tokenizer_file) >> tokenizer; }
         else {
             cout << "Pre-trained tokenizer not found at: " << tokenizer_file << endl;
             return 1;
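
Note on the new load path: it is a three-way fallback. If the base model exists and no fine-tuning checkpoint has been written yet, fine-tuning starts from the base model; if a previously fine-tuned model exists (again with no checkpoint pending), training resumes from that; otherwise only the tokenizer is loaded. When the "chkpt-" file does exist, neither model file is deserialized, because set_synchronization_file() has dlib's dnn_trainer restore its state, network weights included, from that checkpoint when training starts. The sketch below restates the priority logic in isolation; resolve_start_point and the local file_exists stand-in are hypothetical helpers for illustration, not part of the example:

    #include <fstream>
    #include <string>

    // Local stand-in for dlib's file_exists(), so the sketch compiles alone.
    static bool file_exists(const std::string& name)
    {
        return std::ifstream(name).good();
    }

    // Hypothetical helper mirroring the commit's load-priority logic.
    enum class start_point { base_model, finetuned_model, tokenizer_only, missing };

    static start_point resolve_start_point(const std::string& model_file,
                                           const std::string& finetuned_model,
                                           const std::string& tokenizer_file)
    {
        // A pending "chkpt-" file means an interrupted fine-tuning run; the
        // trainer resumes from it, so neither serialized model is reloaded.
        const bool resuming = file_exists("chkpt-" + finetuned_model);
        if (file_exists(model_file) && !resuming)      return start_point::base_model;
        if (file_exists(finetuned_model) && !resuming) return start_point::finetuned_model;
        if (file_exists(tokenizer_file))               return start_point::tokenizer_only;
        return start_point::missing;
    }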
@@ -336,22 +347,6 @@ int main(int argc, char** argv)
         // Release memory
         qa_tokens.clear();
 
-        // Setup trainer for fine-tuning
-        train_net net;
-        if (!file_exists("chkpt-" + model_file)) {
-            cerr << "Error: last checkpoint not found: " << (string("chkpt-") + model_file) << "\n";
-            cerr << "Please run --train first to create base model using <slm_mixture_of_experts_ex>.\n";
-            return 1;
-        }
-        dnn_trainer<train_net, adam> trainer(net, adam(weight_decay, beta1, beta2), gpus);
-        trainer.set_learning_rate(learning_rate);
-        trainer.set_min_learning_rate(1e-7);
-        trainer.set_mini_batch_size(batch_size);
-        trainer.set_max_num_epochs(max_epochs);
-        trainer.set_iterations_without_progress_threshold(patience);
-        trainer.set_synchronization_file("chkpt-" + model_file, std::chrono::minutes(10));
-        trainer.be_quiet();
-
         cout << "Applying freezing strategy...\n";
         set_all_learning_rate_multipliers(net, 0.1);
         layer<1>(net).layer_details().set_learning_rate_multiplier(1.0); // linear
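
This hunk is the other half of the move: the trainer block now runs before data preparation (first hunk), and the hard requirement for a base-training checkpoint is dropped, since the fine-tuning run keeps its own "chkpt-" file. The surviving context lines show the freezing strategy, which soft-freezes layers by scaling their learning rates rather than detaching them. A minimal sketch of the same idea on a toy dlib network (the 0.1/1.0 split follows the example; the toy architecture does not):

    #include <dlib/dnn.h>
    using namespace dlib;

    // Toy network: two fully connected layers over a flat float input.
    using toy_net = loss_multiclass_log<fc<2, relu<fc<8, input<matrix<float>>>>>>;

    int main()
    {
        toy_net net;
        // Soft-freeze everything: gradients still flow, but parameter
        // updates are scaled to 10% of the learning rate...
        set_all_learning_rate_multipliers(net, 0.1);
        // ...except the top linear layer, which trains at the full rate,
        // mirroring layer<1>(net) in the example.
        layer<1>(net).layer_details().set_learning_rate_multiplier(1.0);
    }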
@@ -392,7 +387,7 @@ int main(int argc, char** argv)
             steps += batch_samples.size();
 
             // Progress reporting
-            if (batches_count++ % 50 == 0) {
+            if (batches_count++ % 100 == 0) {
                 double avg_loss = total_loss / batches_seen;
                 auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                     std::chrono::high_resolution_clock::now() - epoch_start).count();
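
The last hunk only thins out the logging: progress is reported every 100 batches instead of every 50. For reference, a self-contained version of the same throttled-reporting pattern, with a placeholder loop and loss values:

    #include <chrono>
    #include <iostream>

    int main()
    {
        using clock = std::chrono::high_resolution_clock;
        const auto epoch_start = clock::now();

        double total_loss = 0.0;
        long batches_seen = 0, batches_count = 0;

        for (int i = 0; i < 1000; ++i) {
            total_loss += 0.5;   // stand-in for the real batch loss
            ++batches_seen;
            // Print on every 100th batch, as in the updated example.
            if (batches_count++ % 100 == 0) {
                const double avg_loss = total_loss / batches_seen;
                const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                    clock::now() - epoch_start).count();
                std::cout << "batch " << batches_count << "  avg loss " << avg_loss
                          << "  elapsed " << elapsed << "s\n";
            }
        }
    }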
