Skip to content

Commit 71a2164

Browse files
committed
Add support for concurrent file processing in ingestion module
- Introduced `max_concurrent_files` option in `IngestConfig` to control the number of files processed simultaneously.
- Updated `process_directory_tree` to utilize a semaphore for managing concurrent tasks.
- Refactored file processing functions to support parallel execution.
- Enhanced command-line interface to accept `max_concurrent` argument.
- Improved metadata loading with modern file reading methods.
1 parent a447728 commit 71a2164

File tree

3 files changed

+163
-38
lines changed

3 files changed

+163
-38
lines changed

rust_ingest/src/ingest.rs

Lines changed: 145 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,9 @@ pub struct IngestConfig {
6565
pub max_chars: usize,
6666
/// Maximum tokens for embedding requests.
6767
pub max_tokens: usize,
68+
/// Maximum number of files to process concurrently.
69+
/// If None, defaults to the number of CPU cores.
70+
pub max_concurrent_files: Option<usize>,
6871
}
6972

7073
impl Default for IngestConfig {
@@ -73,6 +76,7 @@ impl Default for IngestConfig {
7376
root_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
7477
max_chars: MAX_FILE_CHARS,
7578
max_tokens: MAX_EMBEDDING_TOKENS,
79+
max_concurrent_files: None, // Use CPU core count by default
7680
}
7781
}
7882
}
@@ -168,14 +172,44 @@ fn create_hnsw_index() -> Hnsw<'static, f32, DistCosine> {
168172
)
169173
}
170174

171-
/// Processes all files in the directory tree.
175+
/// Processes all files in the directory tree with parallel execution.
172176
async fn process_directory_tree(
173177
config: &IngestConfig,
174178
client: &reqwest::Client,
175179
index: &Hnsw<'_, f32, DistCosine>,
176180
file_metadata: &mut Vec<PathBuf>,
177181
stats: &mut IngestStats,
178182
) -> Result<()> {
183+
use std::sync::atomic::{AtomicUsize, Ordering};
184+
use std::sync::{Arc, Mutex};
185+
use tokio::sync::Semaphore;
186+
use tokio::task::JoinSet;
187+
188+
// Create a semaphore to limit concurrent operations
189+
let max_concurrent = config.max_concurrent_files.unwrap_or_else(|| {
190+
// Default to number of CPUs if not specified
191+
std::thread::available_parallelism()
192+
.map(|n| n.get())
193+
.unwrap_or(4)
194+
});
195+
196+
// Use a semaphore to limit concurrent embedding operations
197+
let semaphore = Arc::new(Semaphore::new(max_concurrent));
198+
199+
// Use JoinSet to manage async tasks
200+
let mut tasks = JoinSet::new();
201+
202+
// Use atomics for thread-safe counters
203+
let processed_count = Arc::new(AtomicUsize::new(0));
204+
let skipped_count = Arc::new(AtomicUsize::new(0));
205+
206+
// Use a mutex to protect the file metadata vector
207+
let file_paths = Arc::new(Mutex::new(Vec::new()));
208+
209+
// Collect candidate files first
210+
let mut candidate_files = Vec::new();
211+
212+
// First pass: find all valid files to process
179213
for entry in WalkDir::new(&config.root_dir)
180214
.into_iter()
181215
.filter_map(|e| e.ok())
@@ -192,29 +226,82 @@ async fn process_directory_tree(
192226
continue;
193227
}
194228

195-
match process_single_file(
196-
path,
197-
config,
198-
client,
199-
index,
200-
file_metadata,
201-
stats.files_processed,
202-
)
203-
.await
204-
{
205-
Ok(()) => {
206-
stats.files_processed += 1;
207-
if stats.files_processed % PROGRESS_INTERVAL == 0 {
208-
println!("Processed {} files…", stats.files_processed);
229+
// Add to candidates
230+
candidate_files.push(path.to_path_buf());
231+
}
232+
233+
println!("Found {} files to process", candidate_files.len());
234+
235+
// Second pass: process files concurrently
236+
for (file_id, path) in candidate_files.into_iter().enumerate() {
237+
// Clone references for the async task
238+
let semaphore_clone = semaphore.clone();
239+
let client_clone = client.clone();
240+
let config_clone = config.clone();
241+
let processed_count_clone = processed_count.clone();
242+
let skipped_count_clone = skipped_count.clone();
243+
let file_paths_clone = file_paths.clone();
244+
let path_clone = path.clone();
245+
246+
// Spawn a task for each file
247+
tasks.spawn(async move {
248+
// Acquire a permit from the semaphore
249+
let _permit = semaphore_clone.acquire().await.unwrap();
250+
251+
// Process the file
252+
match process_single_file_for_embedding(&path_clone, &config_clone, &client_clone).await
253+
{
254+
Ok(embedding) => {
255+
// Successfully processed
256+
let count = processed_count_clone.fetch_add(1, Ordering::SeqCst) + 1;
257+
258+
// Store result
259+
let mut metadata = file_paths_clone.lock().unwrap();
260+
metadata.push((file_id, path_clone, embedding));
261+
262+
// Show progress periodically
263+
if count % PROGRESS_INTERVAL == 0 {
264+
println!("Processed {} files…", count);
265+
}
266+
267+
Ok(())
268+
}
269+
Err(e) => {
270+
// Log error and count as skipped
271+
eprintln!(
272+
"Warning: Failed to process file {}: {}",
273+
path_clone.display(),
274+
e
275+
);
276+
skipped_count_clone.fetch_add(1, Ordering::SeqCst);
277+
Err(e)
209278
}
210279
}
211-
Err(e) => {
212-
eprintln!("Warning: Failed to process file {}: {}", path.display(), e);
213-
stats.files_skipped += 1;
214-
}
215-
}
280+
});
216281
}
217282

283+
// Wait for all tasks to complete
284+
while let Some(result) = tasks.join_next().await {
285+
// Just check for panics, errors are already handled in the task
286+
result?;
287+
}
288+
289+
// Update stats from atomic counters
290+
stats.files_processed += processed_count.load(Ordering::SeqCst);
291+
stats.files_skipped += skipped_count.load(Ordering::SeqCst);
292+
293+
// Sort results by file_id and insert into the index
294+
let mut results = file_paths.lock().unwrap();
295+
results.sort_by_key(|(id, _, _)| *id);
296+
297+
// Now populate the index and metadata
298+
for (_, path, embedding) in results.iter() {
299+
let file_id = file_metadata.len();
300+
index.insert((embedding.as_slice(), file_id));
301+
file_metadata.push(path.clone());
302+
}
303+
304+
println!("Successfully indexed {} files", file_metadata.len());
218305
Ok(())
219306
}
220307

@@ -232,29 +319,26 @@ fn is_supported_file(path: &std::path::Path) -> bool {
232319
.is_some_and(|ext| SUPPORTED_EXTENSIONS.contains(&ext))
233320
}
234321

235-
/// Processes a single file and adds it to the index.
322+
/// Processes a single file for embedding without modifying the index.
323+
///
324+
/// This function handles just the embedding part, making it suitable for
325+
/// parallel processing in our async pipeline.
236326
///
237327
/// # Arguments
238328
/// * `path` - Path to the file being processed
239329
/// * `config` - Configuration settings for ingestion
240330
/// * `client` - HTTP client for embedding API requests
241-
/// * `index` - HNSW index to insert embeddings into
242-
/// * `file_metadata` - Collection of file paths to track processed files
243-
/// * `file_id` - Unique identifier for this file in the index
244331
///
245332
/// # Returns
246-
/// Success if the file was processed and added to the index.
333+
/// The embedding vector on success
247334
///
248335
/// # Errors
249336
/// Returns error if file reading or embedding generation fails.
250-
async fn process_single_file(
337+
async fn process_single_file_for_embedding(
251338
path: &std::path::Path,
252339
config: &IngestConfig,
253340
client: &reqwest::Client,
254-
index: &Hnsw<'_, f32, DistCosine>,
255-
file_metadata: &mut Vec<PathBuf>,
256-
file_id: usize,
257-
) -> Result<()> {
341+
) -> Result<Vec<f32>> {
258342
// Read and truncate file content
259343
let content = std::fs::read_to_string(path)
260344
.with_context(|| format!("Failed to read file: {}", path.display()))?;
@@ -266,6 +350,37 @@ async fn process_single_file(
266350
.await
267351
.with_context(|| format!("Failed to generate embedding for file: {}", path.display()))?;
268352

353+
Ok(embedding)
354+
}
355+
356+
/// Processes a single file and adds it to the index.
357+
///
358+
/// # Arguments
359+
/// * `path` - Path to the file being processed
360+
/// * `config` - Configuration settings for ingestion
361+
/// * `client` - HTTP client for embedding API requests
362+
/// * `index` - HNSW index to insert embeddings into
363+
/// * `file_metadata` - Collection of file paths to track processed files
364+
/// * `file_id` - Unique identifier for this file in the index
365+
///
366+
/// # Returns
367+
/// Success if the file was processed and added to the index.
368+
///
369+
/// # Errors
370+
/// Returns error if file reading or embedding generation fails.
371+
///
372+
/// @deprecated Use the parallel processing pipeline instead
373+
async fn process_single_file(
374+
path: &std::path::Path,
375+
config: &IngestConfig,
376+
client: &reqwest::Client,
377+
index: &Hnsw<'_, f32, DistCosine>,
378+
file_metadata: &mut Vec<PathBuf>,
379+
file_id: usize,
380+
) -> Result<()> {
381+
// Generate embedding
382+
let embedding = process_single_file_for_embedding(path, config, client).await?;
383+
269384
// Insert into index
270385
index.insert((embedding.as_slice(), file_id));
271386

rust_ingest/src/main.rs

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,11 @@ enum Commands {
4848
/// Maximum tokens for embedding requests
4949
#[arg(long, default_value = "600")]
5050
max_tokens: usize,
51+
52+
/// Maximum number of files to process concurrently
53+
/// If not specified, defaults to number of CPU cores
54+
#[arg(long)]
55+
max_concurrent: Option<usize>,
5156
},
5257

5358
/// Ask a question using the pre-built index.
@@ -88,6 +93,7 @@ async fn main() -> Result<()> {
8893
root,
8994
max_chars,
9095
max_tokens,
96+
max_concurrent,
9197
} => {
9298
// Create ingest configuration
9399
let mut config = rust_ingest::ingest::IngestConfig::default();
@@ -97,6 +103,7 @@ async fn main() -> Result<()> {
97103
}
98104
config.max_chars = max_chars;
99105
config.max_tokens = max_tokens;
106+
config.max_concurrent_files = max_concurrent;
100107

101108
// Run ingestion process
102109
let stats = ingest::run_with_config(config).await?;

rust_ingest/src/query.rs

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -159,21 +159,24 @@ fn load_index_and_metadata(
159159
) -> Result<(Hnsw<'static, f32, DistCosine>, Vec<PathBuf>)> {
160160
let data_dir = config.root_dir.join("data");
161161

162-
// Load the HNSW index using the correct API with HnswIo
163-
let mut hnsw_loader = HnswIo::new(&data_dir, "index");
164-
let loaded_index: Hnsw<'_, f32, DistCosine> = hnsw_loader
165-
.load_hnsw()
162+
// Load the HNSW index using the modern file_load method
163+
let index_path = data_dir.join("index");
164+
let loaded_index: Hnsw<'_, f32, DistCosine> = HnswIo::file_load(&index_path)
166165
.context("Failed to load HNSW index - ensure ingestion has been run")?;
167166

168167
// Convert the loaded index to an owned index with 'static lifetime
169168
// SAFETY: This transmute extends the lifetime of the index, which is safe
170169
// because we're taking ownership of the index and ensuring it outlives
171170
// its original borrow
172-
let index: Hnsw<'static, f32, DistCosine> = unsafe { std::mem::transmute(loaded_index) }; // Load file metadata
173-
let metadata_file =
174-
fs::File::open(data_dir.join("meta.json")).context("Failed to open metadata file")?;
171+
let index: Hnsw<'static, f32, DistCosine> = unsafe { std::mem::transmute(loaded_index) };
172+
173+
// Load file metadata
174+
let metadata_path = data_dir.join("meta.json");
175+
let metadata_content = fs::read_to_string(&metadata_path)
176+
.with_context(|| format!("Failed to read metadata file: {}", metadata_path.display()))?;
177+
175178
let metadata: Vec<PathBuf> =
176-
serde_json::from_reader(metadata_file).context("Failed to parse metadata JSON")?;
179+
serde_json::from_str(&metadata_content).context("Failed to parse metadata JSON")?;
177180

178181
Ok((index, metadata))
179182
}

0 commit comments

Comments (0)