Commit 6460e81
Update
1 parent e496c7d

1 file changed: +230 -0 lines changed

examples/slm_mixture_of_experts_ex.cpp

Lines changed: 230 additions & 0 deletions
@@ -356,6 +356,200 @@ moe_param_info get_moe_param_info(const net_type& net, long num_layers)
     return info;
 }
 
+// Reads entire file content into a string.
+std::string read_file_content(const std::string& filepath)
+{
+    std::ifstream file(filepath, std::ios::binary);
+    if (!file) {
+        cerr << "Warning: Cannot open file: " << filepath << "\n";
+        return "";
+    }
+
+    std::stringstream buffer;
+    buffer << file.rdbuf();
+    return buffer.str();
+}
+
+// Replaces each run of two or more newlines with a single "@@" delimiter.
+std::string normalize_paragraph_delimiters(const std::string& text)
+{
+    std::string result;
+    result.reserve(text.size());
+
+    size_t i = 0;
+    while (i < text.size()) {
+        // Check for double (or more) newlines
+        if (i + 1 < text.size() && text[i] == '\n' && text[i + 1] == '\n') {
+            result += "@@";
+            i += 2;
+
+            // Skip any additional consecutive newlines
+            while (i < text.size() && text[i] == '\n') ++i;
+        }
+        else {
+            result += text[i];
+            ++i;
+        }
+    }
+
+    return result;
+}
+
+// Recursively collects all text files from a directory using Dlib's directory class.
+void collect_text_files_recursive(
+    const directory& dir,
+    std::vector<std::string>& text_files,
+    size_t max_files = 0
+)
+{
+    // Process files in current directory
+    for (const auto& file : dir.get_files()) {
+        if (max_files > 0 && text_files.size() >= max_files) return;
+
+        // Check if it's a text file using file type detection
+        file_content_type content_type;
+        if (detect_file_type(file.full_name(), content_type)) {
+            text_files.push_back(file.full_name());
+            cout << " Found text file: " << file.name() << "\n";
+        }
+    }
+
+    // Recursively process subdirectories
+    for (const auto& subdir : dir.get_dirs()) {
+        if (max_files > 0 && text_files.size() >= max_files) {
+            return;
+        }
+        collect_text_files_recursive(subdir, text_files, max_files);
+    }
+}
+
+// Loads external text data from a file or directory.
+std::string load_external_data(
+    const std::string& path,
+    bool normalize_delimiters = true
+)
+{
+    std::string combined_text;
+
+    try {
+        // Try as directory first
+        directory dir(path);
+
+        cout << "Scanning directory recursively: " << path << "\n";
+
+        std::vector<std::string> text_files;
+        collect_text_files_recursive(dir, text_files);
+
+        cout << "Found " << text_files.size() << " text file(s)\n";
+
+        if (text_files.empty()) {
+            cerr << "Warning: No text files found in directory\n";
+            return "";
+        }
+
+        // Sort files for consistent ordering
+        std::sort(text_files.begin(), text_files.end());
+
+        // Concatenate all files with delimiter
+        size_t total_bytes = 0;
+        for (const auto& filepath : text_files) {
+            std::string content = read_file_content(filepath);
+            if (!content.empty()) {
+                combined_text += content;
+
+                // Ensure content ends with a delimiter before the next file
+                if (combined_text.size() < 2 ||
+                    combined_text.substr(combined_text.size() - 2) != "@@") {
+                    combined_text += "@@";
+                }
+
+                total_bytes += content.size();
+            }
+        }
+
+        cout << "Total loaded: " << total_bytes << " bytes from "
+             << text_files.size() << " file(s)\n";
+    }
+    catch (const directory::dir_not_found&) {
+        // Not a directory, try as single file
+        cout << "Loading single text file: " << path << "\n";
+
+        // Verify it's a text file
+        file_content_type content_type;
+        if (!detect_file_type(path, content_type)) {
+            cerr << "Error: File does not appear to be text: " << path << "\n";
+            cerr << "Only plain text files are supported for training.\n";
+            return "";
+        }
+
+        combined_text = read_file_content(path);
+
+        if (combined_text.empty()) {
+            cerr << "Warning: File is empty or could not be read\n";
+            return "";
+        }
+
+        cout << "Loaded " << combined_text.size() << " bytes from file\n";
+    }
+    catch (const std::exception& e) {
+        cerr << "Error loading external data: " << e.what() << "\n";
+        return "";
+    }
+
+    // Normalize paragraph delimiters if requested
+    if (normalize_delimiters && !combined_text.empty())
+        combined_text = normalize_paragraph_delimiters(combined_text);
+
+    return combined_text;
+}
+
+// Parses text with @@ delimiters into individual segments.
+std::vector<std::string> parse_delimited_segments(const std::string& text)
+{
+    std::vector<std::string> segments;
+    std::string delimiter = "@@";
+
+    size_t start = 0;
+    size_t end = text.find(delimiter);
+
+    while (end != std::string::npos) {
+        std::string segment = text.substr(start, end - start);
+
+        // Trim whitespace
+        size_t first = segment.find_first_not_of(" \t\n\r");
+        if (first != std::string::npos) {
+            size_t last = segment.find_last_not_of(" \t\n\r");
+            segment = segment.substr(first, last - first + 1);
+
+            // Add non-empty segments
+            if (!segment.empty()) {
+                segments.push_back(segment);
+            }
+        }
+
+        start = end + delimiter.length();
+        end = text.find(delimiter, start);
+    }
+
+    // Handle last segment
+    if (start < text.size()) {
+        std::string segment = text.substr(start);
+        size_t first = segment.find_first_not_of(" \t\n\r");
+        if (first != std::string::npos) {
+            size_t last = segment.find_last_not_of(" \t\n\r");
+            segment = segment.substr(first, last - first + 1);
+            if (!segment.empty()) {
+                segments.push_back(segment);
+            }
+        }
+    }
+
+    return segments;
+}
+
 int main(int argc, char** argv)
 {
     try
@@ -376,6 +570,7 @@ int main(int argc, char** argv)
         parser.add_option("model-file", "Path for model (default: dlib_lm_moe_model.dat)", 1);
         parser.add_option("tokenizer-file", "Path for tokenizer (default: dlib_lm_tokenizer.vocab)", 1);
         parser.add_option("output-file", "Path for generated output (default: generated_text.txt)", 1);
+        parser.add_option("external-data", "Path to external text data (file or directory)", 1);
         parser.parse(argc, argv);
 
         if (parser.number_of_arguments() == 0 &&
@@ -422,6 +617,38 @@ int main(int argc, char** argv)
         };
         auto text_segments = get_dataset_as_segments(text_datasets);
 
+        // Load external data if provided
+        std::string external_corpus_for_tokenizer;
+        if (parser.option("external-data")) {
+            std::string external_path = parser.option("external-data").argument();
+            cout << "External source: " << external_path << "\n";
+
+            std::string external_text = load_external_data(external_path, true);
+            if (!external_text.empty()) {
+                // Store raw text for tokenizer training (if needed later)
+                external_corpus_for_tokenizer = external_text;
+
+                // Parse into segments for training
+                cout << "Parsing external data into segments...\n";
+                auto external_segments = parse_delimited_segments(external_text);
+                cout << "Parsed " << external_segments.size() << " external segments\n";
+
+                if (!external_segments.empty()) {
+                    // Add to training data
+                    size_t original_count = text_segments.size();
+                    text_segments.insert(text_segments.end(),
+                        external_segments.begin(), external_segments.end());
+
+                    cout << "Training segments: " << original_count
+                         << " (internal) + " << external_segments.size()
+                         << " (external) = " << text_segments.size() << " (total)\n";
+                }
+            }
+            else {
+                cerr << "Warning: No valid external data loaded, continuing with internal datasets only\n";
+            }
+        }
+
         // Tokens filename
         const std::string tokens_file = "dlib_datasets_tokens.bin";
 
@@ -489,6 +716,9 @@ int main(int argc, char** argv)
             + get_dataset_as_text(dataset_id::BLACK_HOLE_QA_PARTC) + delimiter
             + get_dataset_as_text(dataset_id::GENERAL_KNOWLEDGE);
 
+        if (!external_corpus_for_tokenizer.empty())
+            tokenizer_corpus += delimiter + external_corpus_for_tokenizer;
+
         // Replace all "@@" delimiters with spaces
         size_t pos = 0;
         while ((pos = tokenizer_corpus.find(delimiter, pos)) != std::string::npos) {

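For context, the sketch below isolates the "@@" delimiter convention the new helpers rely on: paragraphs separated by blank lines are collapsed to "@@" markers, then split back into trimmed segments. It is a minimal standalone illustration, not code from this commit; the helper names (to_delimited, split_segments) and the sample input are invented for the example.

// Minimal standalone sketch (not part of the commit) of the "@@" delimiter
// convention used by the helpers added above.
#include <iostream>
#include <string>
#include <vector>

// Collapse each run of two or more newlines into a single "@@" delimiter.
std::string to_delimited(const std::string& text)
{
    std::string out;
    for (size_t i = 0; i < text.size(); )
    {
        if (i + 1 < text.size() && text[i] == '\n' && text[i + 1] == '\n')
        {
            out += "@@";
            while (i < text.size() && text[i] == '\n') ++i;
        }
        else
        {
            out += text[i];
            ++i;
        }
    }
    return out;
}

// Split on "@@" and trim leading/trailing whitespace from every piece.
std::vector<std::string> split_segments(const std::string& text)
{
    std::vector<std::string> segments;
    size_t start = 0;
    while (start <= text.size())
    {
        const size_t end = text.find("@@", start);
        const std::string piece = (end == std::string::npos)
            ? text.substr(start)
            : text.substr(start, end - start);

        const size_t first = piece.find_first_not_of(" \t\n\r");
        if (first != std::string::npos)
            segments.push_back(piece.substr(first, piece.find_last_not_of(" \t\n\r") - first + 1));

        if (end == std::string::npos) break;
        start = end + 2;  // skip the "@@" delimiter
    }
    return segments;
}

int main()
{
    const std::string raw = "First paragraph.\n\nSecond paragraph.\n\n\nThird paragraph.\n";
    for (const auto& s : split_segments(to_delimited(raw)))
        std::cout << "[" << s << "]\n";
    // Prints [First paragraph.], [Second paragraph.] and [Third paragraph.],
    // each on its own line.
}

A corpus passed via --external-data flows through the same two steps in the commit: load_external_data normalizes the raw text and parse_delimited_segments splits it before the segments are appended to the internal training data.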