@@ -257,16 +257,27 @@ int main(int argc, char** argv)
     cout << "=== FINE-TUNING MODE ===\n";
     cout << "Objective: specialize model for conversational Q&A with proper formatting\n\n";

-    // Load tokenizer & last checkpoint from the global training
+    // Setup trainer for fine-tuning
     std::string finetuned_model = model_file.substr(0, model_file.find_last_of('.'))
         + "_finetuned.dat";
+    train_net net;
+    dnn_trainer<train_net, adam> trainer(net, adam(weight_decay, beta1, beta2), gpus);
+    trainer.set_learning_rate(learning_rate);
+    trainer.set_min_learning_rate(1e-7);
+    trainer.set_mini_batch_size(batch_size);
+    trainer.set_max_num_epochs(max_epochs);
+    trainer.set_iterations_without_progress_threshold(patience);
+    trainer.set_synchronization_file("chkpt-" + finetuned_model, std::chrono::minutes(5));
+    trainer.be_quiet();

+    // Load tokenizer & model
     bpe_tokenizer tokenizer;
-    if (file_exists(tokenizer_file)) {
-        cout << "Loading pre-trained tokenizer from: " << tokenizer_file << endl;
-        deserialize(tokenizer_file) >> tokenizer;
-        cout << "Tokenizer loaded successfully with vocabulary size: " << tokenizer.get_vocab_size() << endl;
-    }
+    if (file_exists(model_file) &&
+        !file_exists("chkpt-" + finetuned_model)) deserialize(model_file) >> net >> tokenizer;
+    else if (file_exists(finetuned_model) &&
+        !file_exists("chkpt-" + finetuned_model)) deserialize(finetuned_model) >> net >> tokenizer;
+    else if (file_exists(tokenizer_file)) {
+        deserialize(tokenizer_file) >> tokenizer; }
     else {
         cout << "Pre-trained tokenizer not found at: " << tokenizer_file << endl;
         return 1;
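
Note how the new load order interacts with the synchronization file set up just above it: dlib's `dnn_trainer` reloads its state from the synchronization file when one already exists, so the explicit `deserialize` calls are guarded with `!file_exists("chkpt-" + finetuned_model)` and only run on a fresh fine-tuning session. A minimal sketch of that precedence, pulled out into a helper (the helper name and parameter names are illustrative, not part of the commit):

```cpp
#include <dlib/dnn.h>
#include <dlib/dir_nav.h>
#include <dlib/serialize.h>

// Hypothetical helper mirroring the load order above: a trainer checkpoint
// wins (the trainer restores itself from it), then the base model, then a
// previously fine-tuned model; the tokenizer file alone is the last resort.
template <typename net_type, typename tokenizer_type>
bool load_weights_unless_resuming(
    const std::string& checkpoint,      // trainer synchronization file
    const std::string& base_model,      // output of the base training run
    const std::string& finetuned_model, // output of a prior fine-tuning run
    net_type& net,
    tokenizer_type& tokenizer)
{
    if (dlib::file_exists(checkpoint))
        return true;  // nothing to load: the trainer resumes from its checkpoint
    if (dlib::file_exists(base_model))
        dlib::deserialize(base_model) >> net >> tokenizer;
    else if (dlib::file_exists(finetuned_model))
        dlib::deserialize(finetuned_model) >> net >> tokenizer;
    else
        return false; // caller falls back to loading the tokenizer alone
    return true;
}
```

Keeping the fine-tuning checkpoint under its own name (`"chkpt-" + finetuned_model`) also means a fine-tuning run can no longer clobber the base training's checkpoint.
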
@@ -336,22 +347,6 @@ int main(int argc, char** argv)
     // Release memory
     qa_tokens.clear();

-    // Setup trainer for fine-tuning
-    train_net net;
-    if (!file_exists("chkpt-" + model_file)) {
-        cerr << "Error: last checkpoint not found: " << (string("chkpt-") + model_file) << "\n";
-        cerr << "Please run --train first to create base model using <slm_mixture_of_experts_ex>.\n";
-        return 1;
-    }
-    dnn_trainer<train_net, adam> trainer(net, adam(weight_decay, beta1, beta2), gpus);
-    trainer.set_learning_rate(learning_rate);
-    trainer.set_min_learning_rate(1e-7);
-    trainer.set_mini_batch_size(batch_size);
-    trainer.set_max_num_epochs(max_epochs);
-    trainer.set_iterations_without_progress_threshold(patience);
-    trainer.set_synchronization_file("chkpt-" + model_file, std::chrono::minutes(10));
-    trainer.be_quiet();
-
     cout << "Applying freezing strategy...\n";
     set_all_learning_rate_multipliers(net, 0.1);
     layer<1>(net).layer_details().set_learning_rate_multiplier(1.0); // linear
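
The two surviving context lines implement the freezing strategy: every layer's learning-rate multiplier drops to 0.1, then the output head (`layer<1>`, the final linear layer) is restored to 1.0 so it adapts at full speed. A standalone sketch of the same pattern on a toy network (the toy architecture is an assumption for illustration, not the example's actual `train_net`):

```cpp
#include <dlib/dnn.h>
using namespace dlib;

// Toy stand-in: loss <- fc<10> <- relu <- fc<64> <- input.
using toy_net = loss_multiclass_log<fc<10, relu<fc<64, input<matrix<float>>>>>>;

int main()
{
    toy_net net;

    // Slow every trainable layer to 10% of the base learning rate...
    set_all_learning_rate_multipliers(net, 0.1);

    // ...then let the output head train at full speed. layer<1> is the
    // first layer under the loss, here the final fc<10>.
    layer<1>(net).layer_details().set_learning_rate_multiplier(1.0);
}
```
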
@@ -392,7 +387,7 @@ int main(int argc, char** argv)
         steps += batch_samples.size();

         // Progress reporting
-        if (batches_count++ % 50 == 0) {
+        if (batches_count++ % 100 == 0) {
             double avg_loss = total_loss / batches_seen;
             auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                 std::chrono::high_resolution_clock::now() - epoch_start).count();
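
One note on the retuned cadence: since `total_loss` and `batches_seen` appear to accumulate across the epoch, `avg_loss` is a running epoch average, so printing every 100 mini-batches instead of 50 halves console noise without losing information. A self-contained sketch of the reporting arithmetic, with all counters and per-batch values hypothetical:

```cpp
#include <chrono>
#include <iostream>

int main()
{
    using clock = std::chrono::high_resolution_clock;
    const auto epoch_start = clock::now();

    double total_loss = 0.0;
    std::size_t batches_count = 0, batches_seen = 0, steps = 0;

    for (int batch = 0; batch < 1000; ++batch)
    {
        total_loss += 0.5;  // stand-in for the loss reported by the trainer
        ++batches_seen;
        steps += 32;        // stand-in for batch_samples.size()

        // Same cadence as the patched code: report every 100 mini-batches.
        if (batches_count++ % 100 == 0)
        {
            const double avg_loss = total_loss / batches_seen;
            const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
                clock::now() - epoch_start).count();
            std::cout << "batches: " << batches_count
                      << "  avg_loss: " << avg_loss
                      << "  samples: " << steps
                      << "  elapsed: " << elapsed << "s\n";
        }
    }
}
```
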