5656using namespace std ;
5757using namespace dlib ;
5858
59- namespace ernie
59+ namespace dlib
6060{
61+ /*!
62+ @class rotary_positional_embedding_
63+ @brief Implements Rotary Positional Embeddings (RoPE) for transformers
64+
65+ This layer applies rotary positional embeddings to queries and keys in
66+ self-attention layers, providing relative positional information without
67+ absolute position embeddings.
68+
69+ The implementation follows the RoPE formulation from [2], where positions
70+ are encoded through rotation matrices applied to pairs of dimensions.
71+ !*/
6172 class rotary_positional_embedding_ {
6273 public:
6374 explicit rotary_positional_embedding_ () = default;
@@ -386,7 +397,7 @@ namespace ernie
386397 struct model_info {
387398 static std::string describe () {
388399 std::stringstream ss;
389- ss << " ERNIE Transformer model configuration:\n "
400+ ss << " Transformer model configuration:\n "
390401 << " - vocabulary size: " << VOCAB_SIZE << " \n "
391402 << " - layers: " << NUM_LAYERS << " \n "
392403 << " - attention heads: " << NUM_HEADS << " \n "
@@ -674,9 +685,9 @@ int main(int argc, char** argv)
674685 command_line_parser parser;
675686 parser.add_option (" train" , " Train a transformer model on enwiki" );
676687 parser.add_option (" generate" , " Generate enwiki from a previously trained model" );
677- parser.add_option (" verify" , " Verify generated output against original enwiki " );
688+ parser.add_option (" verify" , " Verify generated output against original data " );
678689 parser.add_option (" tokenize-only" , " Only tokenize the input file and save tokens" );
679- parser.add_option (" enwiki" , " Path to the enwiki file" , 1 );
690+ parser.add_option (" enwiki" , " Path to the enwiki file (default: enwiki.txt) " , 1 );
680691 parser.add_option (" max-tokens" , " Maximum number of tokens to load in memory" , 1 );
681692 parser.add_option (" max-bytes" , " Maximum number of bytes to process from enwiki" , 1 );
682693 parser.add_option (" percent" , " Percentage of enwiki to process (0-100)" , 1 );
@@ -687,9 +698,9 @@ int main(int argc, char** argv)
687698 parser.add_option (" alpha" , " Set the weight decay for Adam (default: 0.004)" , 1 );
688699 parser.add_option (" beta1" , " Set Adam's first moment coefficient (default: 0.9)" , 1 );
689700 parser.add_option (" beta2" , " Set Adam's second moment coefficient (default: 0.999)" , 1 );
690- parser.add_option (" model-file" , " Path for model (default: ernie_model .dat)" , 1 );
701+ parser.add_option (" model-file" , " Path for model (default: slm_enwiki_model .dat)" , 1 );
691702 parser.add_option (" output-file" , " Path for output (default: enwiki_generated.txt)" , 1 );
692- parser.add_option (" tokenizer" , " Path to pre-trained tokenizer (default: ernie_tokenizer .vocab)" , 1 );
703+ parser.add_option (" tokenizer" , " Path to pre-trained tokenizer (default: enwiki_tokenizer .vocab)" , 1 );
693704 parser.add_option (" tokens-file" , " Path to pre-tokenized tokens file (optional)" , 1 );
694705 parser.add_option (" force-tokenize" , " Force tokenization even if tokens file exists" );
695706 parser.parse (argc, argv);
@@ -710,14 +721,14 @@ int main(int argc, char** argv)
710721 const double alpha = get_option (parser, " alpha" , 0.004 );
711722 const double beta1 = get_option (parser, " beta1" , 0.9 );
712723 const double beta2 = get_option (parser, " beta2" , 0.999 );
713- const std::string model_file = get_option (parser, " model-file" , " ernie_model .dat" );
724+ const std::string model_file = get_option (parser, " model-file" , " slm_enwiki_model .dat" );
714725 const std::string output_file = get_option (parser, " output-file" , " enwiki_generated.txt" );
715- const std::string enwiki_path = get_option (parser, " enwiki" , " enwiki" );
726+ const std::string enwiki_path = get_option (parser, " enwiki" , " enwiki.txt " );
716727 const long max_seq_len = 180 ;
717728 const long num_layers = 2 ;
718729 const long num_heads = 6 ;
719730 const long embedding_dim = 228 ;
720- const std::string tokenizer_path = get_option (parser, " tokenizer" , " ernie_tokenizer .vocab" );
731+ const std::string tokenizer_path = get_option (parser, " tokenizer" , " enwiki_tokenizer .vocab" );
721732 // Default number of prompt tokens = input sequence length
722733 const bool force_tokenize = parser.option (" force-tokenize" );
723734 const long num_tokens = 1000 ;
@@ -760,7 +771,7 @@ int main(int argc, char** argv)
760771 parser.option (" tokens-file" ).argument () :
761772 generate_tokens_filename (enwiki_path, max_bytes);
762773
763- using ernie_transformer = ernie:: transformer_config<
774+ using enwiki_transformer = transformer_config<
764775 num_tokens, // vocab_size
765776 num_layers, // number of layers
766777 num_heads, // number of attention heads
@@ -945,9 +956,9 @@ int main(int argc, char** argv)
945956 cout << " Created " << samples.size () << " training samples (100%)...\n " ;
946957
947958 // 5) Build and train the network
948- using net_type = ernie_transformer ::network_type<true >;
959+ using net_type = enwiki_transformer ::network_type<true >;
949960 net_type net;
950- cout << " Model architecture:\n " << ernie_transformer ::model_info::describe () << endl;
961+ cout << " Model architecture:\n " << enwiki_transformer ::model_info::describe () << endl;
951962 if (file_exists (model_file)) deserialize (model_file) >> net;
952963
953964 // Create trainer
@@ -958,7 +969,7 @@ int main(int argc, char** argv)
958969 // For perfect memorization, we allow more epochs without improvement
959970 trainer.set_iterations_without_progress_threshold (patience);
960971 trainer.set_max_num_epochs (max_epochs); // More epochs for perfect memorization
961- trainer.set_synchronization_file (" ernie_trainer .sync" , std::chrono::minutes (10 ));
972+ trainer.set_synchronization_file (" enwiki_trainer .sync" , std::chrono::minutes (10 ));
962973 trainer.be_quiet ();
963974
964975 // Custom training loop - trainer.train(samples, labels)
@@ -1027,27 +1038,29 @@ int main(int argc, char** argv)
10271038 net.clean ();
10281039 serialize (model_file) << net;
10291040 cout << " Model saved to " << model_file << " \n " ;
1030- std::remove (" ernie_trainer .sync" );
1031- std::remove (" ernie_trainer .sync_" );
1041+ std::remove (" enwiki_trainer .sync" );
1042+ std::remove (" enwiki_trainer .sync_" );
10321043
10331044 // Evaluate on training set
1034- if (!g_terminate_flag.load ()) {
1035- cout << " Evaluating model accuracy...\n " ;
1036- using net_infer = ernie_transformer::network_type<false >;
1037- net_infer g_infer = net;
1038- auto predicted = g_infer (samples);
1039- size_t correct = 0 ;
1040- for (size_t i = 0 ; i < labels.size (); ++i)
1041- if (predicted[i] == labels[i]) correct++;
1042- double accuracy = (double )correct / labels.size ();
1043- cout << " Training accuracy: " << (accuracy * 100.0 ) << " %\n " ;
1044-
1045- // We need perfect accuracy to reconstruct enwiki
1046- if (accuracy < 0.9999 ) {
1047- cout << " WARNING: Model accuracy is less than 99.99%. The model may not "
1048- << " perfectly reconstruct the input text.\n " ;
1045+ {
1046+ if (!g_terminate_flag.load ()) {
1047+ cout << " Evaluating model accuracy...\n " ;
1048+ using net_infer = enwiki_transformer::network_type<false >;
1049+ net_infer g_infer = net;
1050+ auto predicted = g_infer (samples);
1051+ size_t correct = 0 ;
1052+ for (size_t i = 0 ; i < labels.size (); ++i)
1053+ if (predicted[i] == labels[i]) correct++;
1054+ double accuracy = (double )correct / labels.size ();
1055+ cout << " Training accuracy: " << (accuracy * 100.0 ) << " %\n " ;
1056+
1057+ // We need perfect accuracy to reconstruct enwiki
1058+ if (accuracy < 0.9999 ) {
1059+ cout << " WARNING: Model accuracy is less than 99.99%. The model may not "
1060+ << " perfectly reconstruct the input text.\n " ;
1061+ }
10491062 }
1050- }
1063+ }
10511064 }
10521065
10531066 // ----------------------------------------------------------------------------------------
@@ -1058,7 +1071,7 @@ int main(int argc, char** argv)
10581071 cout << " === GENERATION MODE ===\n " ;
10591072
10601073 // 1) Load the model
1061- using net_infer = ernie_transformer ::network_type<false >;
1074+ using net_infer = enwiki_transformer ::network_type<false >;
10621075 net_infer net;
10631076 if (file_exists (model_file)) {
10641077 deserialize (model_file) >> net;
0 commit comments