Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 109 additions & 46 deletions rust/index/src/spann/quantized_spann.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::sync::{
};

use chroma_blockstore::{
arrow::provider::BlockfileReaderOptions, BlockfileFlusher, BlockfileReader,
BlockfileWriterOptions,
arrow::provider::BlockfileReaderOptions, provider::BlockfileProvider, BlockfileFlusher,
BlockfileReader, BlockfileWriterOptions,
};
use chroma_distance::{normalize, DistanceFunction};
use chroma_error::{ChromaError, ErrorCodes};
Expand All @@ -26,16 +26,14 @@ use faer::{
Mat,
};
use simsimd::SpatialSimilarity;

use thiserror::Error;

use chroma_blockstore::provider::BlockfileProvider;
use uuid::Uuid;

use crate::{
quantization::Code,
spann::{types::QuantizedSpannIds, utils},
spann::utils,
usearch::{USearchIndex, USearchIndexConfig, USearchIndexProvider},
OpenMode, SearchResult, VectorIndex, VectorIndexProvider,
IndexUuid, OpenMode, SearchResult, VectorIndex, VectorIndexProvider,
};

// Blockfile prefixes
Expand All @@ -58,6 +56,16 @@ struct QuantizedDelta {
versions: Vec<u32>,
}

#[derive(Clone, Debug)]
pub struct QuantizedSpannIds {
pub embedding_metadata_id: Uuid,
pub prefix_path: String,
pub quantized_centroid_id: IndexUuid,
pub quantized_cluster_id: Uuid,
pub raw_centroid_id: IndexUuid,
pub scalar_metadata_id: Uuid,
}

#[derive(Error, Debug)]
pub enum QuantizedSpannError {
#[error("Centroid index error: {0}")]
Expand All @@ -79,8 +87,10 @@ impl ChromaError for QuantizedSpannError {
}

/// Mutable quantized SPANN index, generic over centroid index.
#[derive(Clone)]
pub struct QuantizedSpannIndexWriter<I: VectorIndex> {
// === Config ===
cluster_block_size: usize,
cmek: Option<Cmek>,
collection_id: CollectionUuid,
config: SpannIndexConfig,
Expand Down Expand Up @@ -798,39 +808,17 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
/// Commit all in-memory state to blockfile writers and return a flusher.
///
/// This method consumes the index and prepares all data for persistence.
/// Call `flush()` on the returned flusher to actually write to storage.
/// Call `finish()` before this method, then `flush()` on the returned
/// flusher to actually write to storage.
pub async fn commit(
mut self,
self,
blockfile_provider: &BlockfileProvider,
usearch_provider: &USearchIndexProvider,
) -> Result<QuantizedSpannFlusher, QuantizedSpannError> {
// === Step 0: Pre-scrub cleanup ===
let mut mutated_cluster_ids = self
.cluster_deltas
.iter()
.filter_map(|entry| (!entry.value().ids.is_empty()).then_some(*entry.key()))
.collect::<Vec<_>>();

for cluster_id in &mutated_cluster_ids {
self.scrub(*cluster_id).await?;
}

let zero_length_cluster_ids = self
.cluster_deltas
.iter()
.filter_map(|entry| (entry.value().length == 0).then_some(*entry.key()))
.collect::<Vec<_>>();

for cluster_id in zero_length_cluster_ids {
self.drop(cluster_id).await?;
}

// === Step 1: Check center drift and rebuild centroid indexes if needed ===
self.rebuild_on_drift(usearch_provider).await?;

// === Step 2: Create blockfile writers ===
let mut qc_options =
BlockfileWriterOptions::new(self.prefix_path.clone()).ordered_mutations();
// === Create blockfile writers ===
let mut qc_options = BlockfileWriterOptions::new(self.prefix_path.clone())
.ordered_mutations()
.max_block_size_bytes(self.cluster_block_size);
let mut sm_options =
BlockfileWriterOptions::new(self.prefix_path.clone()).ordered_mutations();
let mut em_options =
Expand Down Expand Up @@ -862,9 +850,14 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
.await
.map_err(|err| QuantizedSpannError::Blockfile(err.boxed()))?;

// === Step 3: Write quantized_cluster data ===
// === Write quantized_cluster data ===
let quantized_cluster_flusher = {
// Add tombstoned cluster ids (need to delete from blockfile)
// Collect clusters that received mutations plus tombstones.
let mut mutated_cluster_ids = self
.cluster_deltas
.iter()
.filter_map(|entry| (!entry.value().ids.is_empty()).then_some(*entry.key()))
.collect::<Vec<_>>();
for cluster_id in self.tombstones.iter() {
mutated_cluster_ids.push(*cluster_id);
}
Expand Down Expand Up @@ -905,7 +898,7 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
.map_err(QuantizedSpannError::Blockfile)?
};

// === Step 4: Write scalar_metadata ===
// === Write scalar_metadata ===
let scalar_metadata_flusher = {
// 1. PREFIX_LENGTH - sorted by cluster_id
let mut lengths = self
Expand Down Expand Up @@ -948,7 +941,7 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
.map_err(QuantizedSpannError::Blockfile)?
};

// === Step 5: Write embedding_metadata ===
// === Write embedding_metadata ===
let embedding_metadata_flusher = {
// 1. PREFIX_CENTER - quantization center (always write, may be updated)
embedding_metadata_writer
Expand Down Expand Up @@ -978,6 +971,7 @@ impl QuantizedSpannIndexWriter<USearchIndex> {

Ok(QuantizedSpannFlusher {
embedding_metadata_flusher,
prefix_path: self.prefix_path.clone(),
quantized_centroid: self.quantized_centroid,
quantized_cluster_flusher,
raw_centroid: self.raw_centroid,
Expand All @@ -987,7 +981,9 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
}

/// Create a new quantized SPANN index.
#[allow(clippy::too_many_arguments)]
pub async fn create(
cluster_block_size: usize,
collection_id: CollectionUuid,
config: SpannIndexConfig,
dimension: usize,
Expand Down Expand Up @@ -1041,6 +1037,7 @@ impl QuantizedSpannIndexWriter<USearchIndex> {

Ok(Self {
// === Config ===
cluster_block_size,
cmek,
collection_id,
config,
Expand Down Expand Up @@ -1068,9 +1065,46 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
})
}

/// Prepare the index for commit: scrub mutated clusters, drop empty
/// clusters, and rebuild centroid indexes if the quantization center
/// has drifted. Must be called before `commit()`.
pub async fn finish(
&mut self,
usearch_provider: &USearchIndexProvider,
) -> Result<(), QuantizedSpannError> {
// Scrub all clusters that received mutations.
let mut mutated_cluster_ids = self
.cluster_deltas
.iter()
.filter_map(|entry| (!entry.value().ids.is_empty()).then_some(*entry.key()))
.collect::<Vec<_>>();
mutated_cluster_ids.sort_unstable();

for cluster_id in &mutated_cluster_ids {
self.scrub(*cluster_id).await?;
}

// Drop clusters that ended up empty after scrubbing.
let zero_length_cluster_ids = self
.cluster_deltas
.iter()
.filter_map(|entry| (entry.value().length == 0).then_some(*entry.key()))
.collect::<Vec<_>>();

for cluster_id in zero_length_cluster_ids {
self.drop(cluster_id).await?;
}

// Check center drift and rebuild centroid indexes if needed.
self.rebuild_on_drift(usearch_provider).await?;

Ok(())
}

/// Open an existing quantized SPANN index from file IDs.
#[allow(clippy::too_many_arguments)]
pub async fn open(
cluster_block_size: usize,
collection_id: CollectionUuid,
config: SpannIndexConfig,
dimension: usize,
Expand Down Expand Up @@ -1234,6 +1268,7 @@ impl QuantizedSpannIndexWriter<USearchIndex> {

Ok(Self {
// === Config ===
cluster_block_size,
cmek,
collection_id,
config,
Expand Down Expand Up @@ -1353,6 +1388,7 @@ impl QuantizedSpannIndexWriter<USearchIndex> {
/// Flusher for persisting a quantized SPANN index to storage.
pub struct QuantizedSpannFlusher {
embedding_metadata_flusher: BlockfileFlusher,
prefix_path: String,
quantized_centroid: USearchIndex,
quantized_cluster_flusher: BlockfileFlusher,
raw_centroid: USearchIndex,
Expand Down Expand Up @@ -1397,6 +1433,7 @@ impl QuantizedSpannFlusher {
// Return file IDs
Ok(QuantizedSpannIds {
embedding_metadata_id,
prefix_path: self.prefix_path.clone(),
quantized_centroid_id,
quantized_cluster_id,
raw_centroid_id,
Expand Down Expand Up @@ -1431,6 +1468,7 @@ mod tests {

use super::{QuantizedDelta, QuantizedSpannIndexWriter};

const TEST_CLUSTER_BLOCK_SIZE: usize = 2 * 1024 * 1024;
const TEST_DIMENSION: usize = 4;
const TEST_EPSILON: f32 = 1e-5;

Expand Down Expand Up @@ -1491,6 +1529,7 @@ mod tests {
let usearch_provider = test_usearch_provider(storage);

let writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
TEST_CLUSTER_BLOCK_SIZE,
CollectionUuid::new(),
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -1741,7 +1780,8 @@ mod tests {
// =======================================================================
// Phase 1: Create index, add points, commit, flush
// =======================================================================
let writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
let mut writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -1784,7 +1824,11 @@ mod tests {
let expected_rotated_101 = writer.rotate(&[0.0, 1.0, 0.0, 0.0]);
let expected_rotated_102 = writer.rotate(&[0.0, 0.0, 1.0, 0.0]);

// Commit and flush
// Finish, commit, and flush
writer
.finish(&usearch_provider)
.await
.expect("Failed to finish");
let flusher = writer
.commit(&blockfile_provider, &usearch_provider)
.await
Expand All @@ -1807,6 +1851,7 @@ mod tests {
.expect("Failed to open raw embedding reader");

let writer = QuantizedSpannIndexWriter::<USearchIndex>::open(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -2016,6 +2061,7 @@ mod tests {
let usearch_provider = test_usearch_provider(storage);

let writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
TEST_CLUSTER_BLOCK_SIZE,
CollectionUuid::new(),
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -2272,7 +2318,7 @@ mod tests {
}

#[tokio::test]
async fn test_open_commit() {
async fn test_open_finish_commit() {
// =======================================================================
// Setup
// =======================================================================
Expand All @@ -2282,7 +2328,8 @@ mod tests {
let usearch_provider = test_usearch_provider(storage.clone());
let collection_id = CollectionUuid::new();

let writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
let mut writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -2371,6 +2418,10 @@ mod tests {

let next_cluster_id_after_spawn = writer.next_cluster_id.load(Ordering::Relaxed);

writer
.finish(&usearch_provider)
.await
.expect("Failed to finish");
let flusher = writer
.commit(&blockfile_provider, &usearch_provider)
.await
Expand All @@ -2383,7 +2434,8 @@ mod tests {
let blockfile_provider = test_blockfile_provider(storage.clone());
let usearch_provider = test_usearch_provider(storage.clone());

let writer = QuantizedSpannIndexWriter::<USearchIndex>::open(
let mut writer = QuantizedSpannIndexWriter::<USearchIndex>::open(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -2424,6 +2476,10 @@ mod tests {
// --- Cluster D: drop ---
writer.drop(cluster_d).await.expect("drop D failed");

writer
.finish(&usearch_provider)
.await
.expect("Failed to finish");
let flusher = writer
.commit(&blockfile_provider, &usearch_provider)
.await
Expand All @@ -2437,6 +2493,7 @@ mod tests {
let usearch_provider = test_usearch_provider(storage.clone());

let writer = QuantizedSpannIndexWriter::<USearchIndex>::open(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -2616,6 +2673,7 @@ mod tests {
let usearch_provider = test_usearch_provider(storage.clone());

let mut writer = QuantizedSpannIndexWriter::<USearchIndex>::create(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down Expand Up @@ -2697,10 +2755,14 @@ mod tests {
let test_vec = [rng.gen(), rng.gen(), rng.gen(), rng.gen()];
let expected_rotated = writer.rotate(&test_vec);

// --- Commit + Flush ---
// --- Finish + Commit + Flush ---
let blockfile_provider = test_blockfile_provider(storage.clone());
let usearch_provider = test_usearch_provider(storage.clone());

writer
.finish(&usearch_provider)
.await
.expect("finish failed");
let flusher = writer
.commit(&blockfile_provider, &usearch_provider)
.await
Expand All @@ -2720,6 +2782,7 @@ mod tests {
.expect("Failed to open raw embedding reader");

writer = QuantizedSpannIndexWriter::<USearchIndex>::open(
TEST_CLUSTER_BLOCK_SIZE,
collection_id,
test_config(),
TEST_DIMENSION,
Expand Down
9 changes: 0 additions & 9 deletions rust/index/src/spann/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2510,15 +2510,6 @@ pub struct SpannIndexIds {
pub prefix_path: String,
}

#[derive(Clone, Debug)]
pub struct QuantizedSpannIds {
pub embedding_metadata_id: Uuid,
pub quantized_centroid_id: IndexUuid,
pub quantized_cluster_id: Uuid,
pub raw_centroid_id: IndexUuid,
pub scalar_metadata_id: Uuid,
}

impl SpannIndexFlusher {
pub async fn flush(self) -> Result<SpannIndexIds, SpannIndexWriterError> {
let res = SpannIndexIds {
Expand Down
Loading
Loading