@@ -356,6 +356,200 @@ moe_param_info get_moe_param_info(const net_type& net, long num_layers)
356356 return info;
357357}
358358
// Loads the complete contents of `filepath` into a string.
//
// The file is opened in binary mode, so bytes (including any "\r\n" line
// endings) are returned exactly as stored on disk.
//
// Returns an empty string, after printing a warning to stderr, when the file
// cannot be opened. An empty file also yields an empty string.
std::string read_file_content(const std::string& filepath)
{
    std::ifstream input(filepath, std::ios::binary);
    if (input) {
        // Drain the whole stream buffer in one operation.
        std::ostringstream contents;
        contents << input.rdbuf();
        return contents.str();
    }

    std::cerr << "Warning: Cannot open file: " << filepath << "\n";
    return std::string();
}
372+
// Collapses every paragraph break in `text` into a single "@@" delimiter.
//
// A paragraph break is a run of two or more consecutive line terminators.
// Both "\n" and "\r\n" count as a line terminator: input files are read in
// binary mode, so text saved with Windows line endings reaches this function
// with its "\r\n" sequences intact and would otherwise never be split.
//
// Single line terminators, and every other character, are copied unchanged.
// Returns the normalized copy; `text` itself is not modified.
std::string normalize_paragraph_delimiters(const std::string& text)
{
    // Length of the line terminator starting at `pos`:
    // 2 for "\r\n", 1 for "\n", 0 if there is none.
    const auto terminator_length = [&text](size_t pos) -> size_t {
        if (pos >= text.size()) return 0;
        if (text[pos] == '\n') return 1;
        if (text[pos] == '\r' && pos + 1 < text.size() && text[pos + 1] == '\n') return 2;
        return 0;
    };

    std::string result;
    result.reserve(text.size());  // output is never longer than the input

    size_t i = 0;
    while (i < text.size()) {
        const size_t first = terminator_length(i);
        const size_t second = (first != 0) ? terminator_length(i + first) : 0;

        if (second != 0) {
            // Two terminators in a row: emit one delimiter ...
            result += "@@";
            i += first + second;

            // ... and swallow any further terminators in the same run.
            for (size_t extra = terminator_length(i); extra != 0; extra = terminator_length(i))
                i += extra;
        }
        else {
            result += text[i];
            ++i;
        }
    }

    return result;
}
397+
398+ // Recursively collects all text files from a directory using Dlib's directory class.
399+ void collect_text_files_recursive (
400+ const directory& dir,
401+ std::vector<std::string>& text_files,
402+ size_t max_files = 0
403+ )
404+ {
405+ // Process files in current directory
406+ for (const auto & file : dir.get_files ()) {
407+ if (max_files > 0 && text_files.size () >= max_files) return ;
408+
409+ // Check if it's a text file using file type detection
410+ file_content_type content_type;
411+ if (detect_file_type (file.full_name (), content_type)) {
412+ text_files.push_back (file.full_name ());
413+ cout << " Found text file: " << file.name () << " \n " ;
414+ }
415+ }
416+
417+ // Recursively process subdirectories
418+ for (const auto & subdir : dir.get_dirs ()) {
419+ if (max_files > 0 && text_files.size () >= max_files) {
420+ return ;
421+ }
422+ collect_text_files_recursive (subdir, text_files, max_files);
423+ }
424+ }
425+
426+ // Loads external text data from a file or directory
427+ std::string load_external_data (
428+ const std::string& path,
429+ bool normalize_delimiters = true
430+ )
431+ {
432+ std::string combined_text;
433+
434+ try {
435+ // Try as directory first
436+ directory dir (path);
437+
438+ cout << " Scanning directory recursively: " << path << " \n " ;
439+
440+ std::vector<std::string> text_files;
441+ collect_text_files_recursive (dir, text_files);
442+
443+ cout << " Found " << text_files.size () << " text file(s)\n " ;
444+
445+ if (text_files.empty ()) {
446+ cerr << " Warning: No text files found in directory\n " ;
447+ return " " ;
448+ }
449+
450+ // Sort files for consistent ordering
451+ std::sort (text_files.begin (), text_files.end ());
452+
453+ // Concatenate all files with delimiter
454+ size_t total_bytes = 0 ;
455+ for (const auto & filepath : text_files) {
456+ std::string content = read_file_content (filepath);
457+ if (!content.empty ()) {
458+ combined_text += content;
459+
460+ // Ensure content ends with delimiter for next file
461+ if (!combined_text.empty () &&
462+ combined_text.size () >= 2 &&
463+ combined_text.substr (combined_text.size () - 2 ) != " @@" ) {
464+ combined_text += " @@" ;
465+ }
466+
467+ total_bytes += content.size ();
468+ }
469+ }
470+
471+ cout << " Total loaded: " << total_bytes << " bytes from "
472+ << text_files.size () << " file(s)\n " ;
473+ }
474+ catch (const directory::dir_not_found&) {
475+ // Not a directory, try as single file
476+ cout << " Loading single text file: " << path << " \n " ;
477+
478+ // Verify it's a text file
479+ file_content_type content_type;
480+ if (!detect_file_type (path, content_type)) {
481+ cerr << " Error: File does not appear to be text: " << path << " \n " ;
482+ cerr << " Only plain text files are supported for training.\n " ;
483+ return " " ;
484+ }
485+
486+ combined_text = read_file_content (path);
487+
488+ if (combined_text.empty ()) {
489+ cerr << " Warning: File is empty or could not be read\n " ;
490+ return " " ;
491+ }
492+
493+ cout << " Loaded " << combined_text.size () << " bytes from file\n " ;
494+ }
495+ catch (const std::exception& e) {
496+ cerr << " Error loading external data: " << e.what () << " \n " ;
497+ return " " ;
498+ }
499+
500+ // Normalize paragraph delimiters if requested
501+ if (normalize_delimiters && !combined_text.empty ()) {
502+ size_t original_size = combined_text.size ();
503+ combined_text = normalize_paragraph_delimiters (combined_text);
504+ }
505+
506+ return combined_text;
507+ }
508+
// Splits `text` on the "@@" delimiter and returns the resulting segments.
//
// Each segment has leading and trailing whitespace (space, tab, "\n", "\r")
// removed. Segments that are empty or whitespace-only are dropped, so
// consecutive, leading, or trailing delimiters never produce empty entries.
// Segment order follows their order in `text`.
std::vector<std::string> parse_delimited_segments(const std::string& text)
{
    constexpr char delimiter[] = "@@";
    constexpr size_t delimiter_length = sizeof(delimiter) - 1;
    constexpr char whitespace[] = " \t\n\r";

    std::vector<std::string> segments;

    size_t start = 0;
    while (start < text.size()) {
        // The end of the text acts as a final, implicit delimiter, so the last
        // segment goes through exactly the same path as the others.
        size_t end = text.find(delimiter, start);
        if (end == std::string::npos)
            end = text.size();

        // Trim in place on `text` instead of copying the raw segment first.
        // The forward scan cannot run past `end` unnoticed: it either stops at
        // a non-whitespace character inside the segment or at/after `end`.
        const size_t first = text.find_first_not_of(whitespace, start);
        if (first != std::string::npos && first < end) {
            // text[first] is non-whitespace, so this search always succeeds
            // with last >= first.
            const size_t last = text.find_last_not_of(whitespace, end - 1);
            segments.emplace_back(text, first, last - first + 1);
        }

        start = end + delimiter_length;
    }

    return segments;
}
552+
359553int main (int argc, char ** argv)
360554{
361555 try
@@ -376,6 +570,7 @@ int main(int argc, char** argv)
376570 parser.add_option (" model-file" , " Path for model (default: dlib_lm_moe_model.dat)" , 1 );
377571 parser.add_option (" tokenizer-file" , " Path for tokenizer (default: dlib_lm_tokenizer.vocab)" , 1 );
378572 parser.add_option (" output-file" , " Path for generated output (default: generated_text.txt)" , 1 );
573+ parser.add_option (" external-data" , " Path to external text data" , 1 );
379574 parser.parse (argc, argv);
380575
381576 if (parser.number_of_arguments () == 0 &&
@@ -422,6 +617,38 @@ int main(int argc, char** argv)
422617 };
423618 auto text_segments = get_dataset_as_segments (text_datasets);
424619
620+ // Load external data if provided
621+ std::string external_corpus_for_tokenizer;
622+ if (parser.option (" external-data" )) {
623+ std::string external_path = parser.option (" external-data" ).argument ();
624+ cout << " Externa source: " << external_path << " \n " ;
625+
626+ std::string external_text = load_external_data (external_path, true );
627+ if (!external_text.empty ()) {
628+ // Store raw text for tokenizer training (if needed later)
629+ external_corpus_for_tokenizer = external_text;
630+
631+ // Parse into segments for training
632+ cout << " Parsing external data into segments...\n " ;
633+ auto external_segments = parse_delimited_segments (external_text);
634+ cout << " Parsed " << external_segments.size () << " external segments\n " ;
635+
636+ if (!external_segments.empty ()) {
637+ // Add to training data
638+ size_t original_count = text_segments.size ();
639+ text_segments.insert (text_segments.end (),
640+ external_segments.begin (), external_segments.end ());
641+
642+ cout << " Training segments: " << original_count
643+ << " (internal) + " << external_segments.size ()
644+ << " (external) = " << text_segments.size () << " (total)\n " ;
645+ }
646+ }
647+ else {
648+ cerr << " Warning: no valid external data loaded, continuing with internal datasets only\n " ;
649+ }
650+ }
651+
425652 // Tokens filename
426653 const std::string tokens_file = " dlib_datasets_tokens.bin" ;
427654
@@ -489,6 +716,9 @@ int main(int argc, char** argv)
489716 + get_dataset_as_text (dataset_id::BLACK_HOLE_QA_PARTC) + delimiter
490717 + get_dataset_as_text (dataset_id::GENERAL_KNOWLEDGE);
491718
719+ if (!external_corpus_for_tokenizer.empty ())
720+ tokenizer_corpus += delimiter + external_corpus_for_tokenizer;
721+
492722 // Replace all "@@" delimiters with spaces
493723 size_t pos = 0 ;
494724 while ((pos = tokenizer_corpus.find (delimiter, pos)) != std::string::npos) {
0 commit comments