Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions chromadb/test/api/test_schema_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,64 @@ def test_schema_vector_config_persistence(

collection_name = f"schema_spann_{uuid4().hex}"

schema = Schema()
schema.create_index(
config=VectorIndexConfig(
space="cosine",
spann=SpannIndexConfig(
search_nprobe=16,
write_nprobe=32,
ef_construction=120,
max_neighbors=24,
),
)
)

collection = client.get_or_create_collection(
name=collection_name,
schema=schema,
)

persisted_schema = collection.schema
assert persisted_schema is not None

print(persisted_schema.serialize_to_json())

embedding_override = persisted_schema.keys["#embedding"].float_list
assert embedding_override is not None
vector_index = embedding_override.vector_index
assert vector_index is not None
assert vector_index.enabled is True
assert vector_index.config is not None
assert vector_index.config.space is not None
assert vector_index.config.space == "cosine"

client_reloaded = client_factories.create_client_from_system()
reloaded_collection = client_reloaded.get_collection(
name=collection_name,
)

reloaded_schema = reloaded_collection.schema
assert reloaded_schema is not None
reloaded_embedding_override = reloaded_schema.keys["#embedding"].float_list
assert reloaded_embedding_override is not None
reloaded_vector_index = reloaded_embedding_override.vector_index
assert reloaded_vector_index is not None
assert reloaded_vector_index.config is not None
assert reloaded_vector_index.config.space is not None
assert reloaded_vector_index.config.space == "cosine"


def test_schema_vector_config_persistence_with_ef(
client_factories: "ClientFactories",
) -> None:
"""Ensure schema-provided SPANN settings persist across client restarts."""

client = client_factories.create_client_from_system()
client.reset()

collection_name = f"schema_spann_{uuid4().hex}"

schema = Schema()
embedding_function = SimpleEmbeddingFunction(dim=6)
schema.create_index(
Expand Down
67 changes: 66 additions & 1 deletion rust/types/src/collection_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1547,8 +1547,13 @@ impl Schema {
if vector_index.enabled {
return false;
}
if !is_embedding_function_default(&vector_index.config.embedding_function) {
return false;
}
if !is_space_default(&vector_index.config.space) {
return false;
}
// space and embedding_function must be default (checked above);
// the remaining config fields must also keep their default structure
if vector_index.config.source_key.is_some() {
return false;
}
Expand Down Expand Up @@ -1613,6 +1618,9 @@ impl Schema {
if !vector_index.enabled {
return false;
}
if !is_space_default(&vector_index.config.space) {
return false;
}
// Check that embedding_function is default
if !is_embedding_function_default(&vector_index.config.embedding_function) {
return false;
Expand Down Expand Up @@ -3804,6 +3812,63 @@ mod tests {
assert!(!schema_with_extra_overrides.is_default());
}

/// Verifies that explicitly setting a vector-index `space` makes a schema
/// non-default, both in the schema-wide defaults and in the `EMBEDDING_KEY` entry.
#[test]
fn test_is_schema_default_with_space() {
    // Baseline: an untouched default schema reports itself as default.
    let baseline = Schema::new_default(KnnIndex::Hnsw);
    assert!(baseline.is_default());

    // Case 1: set the space on the float-list vector index held in `defaults`.
    let mut defaults_variant = Schema::new_default(KnnIndex::Hnsw);
    if let Some(vector_index) = defaults_variant
        .defaults
        .float_list
        .as_mut()
        .and_then(|float_list| float_list.vector_index.as_mut())
    {
        vector_index.config.space = Some(Space::Cosine);
    }
    assert!(!defaults_variant.is_default());

    // Case 2: set the space on the vector index stored under `EMBEDDING_KEY`.
    let mut embedding_key_variant = Schema::new_default(KnnIndex::Spann);
    if let Some(vector_index) = embedding_key_variant
        .keys
        .get_mut(EMBEDDING_KEY)
        .and_then(|value_types| value_types.float_list.as_mut())
        .and_then(|float_list| float_list.vector_index.as_mut())
    {
        vector_index.config.space = Some(Space::Cosine);
    }
    assert!(!embedding_key_variant.is_default());
}

/// Verifies that explicitly setting a vector-index `embedding_function` makes a
/// schema non-default, both in the schema-wide defaults and in the
/// `EMBEDDING_KEY` entry.
#[test]
fn test_is_schema_default_with_embedding_function() {
    // Baseline: an untouched default schema reports itself as default.
    let baseline = Schema::new_default(KnnIndex::Hnsw);
    assert!(baseline.is_default());

    // Case 1: set the embedding function on the vector index held in `defaults`.
    let mut defaults_variant = Schema::new_default(KnnIndex::Hnsw);
    if let Some(vector_index) = defaults_variant
        .defaults
        .float_list
        .as_mut()
        .and_then(|float_list| float_list.vector_index.as_mut())
    {
        vector_index.config.embedding_function = Some(EmbeddingFunctionConfiguration::Legacy);
    }
    assert!(!defaults_variant.is_default());

    // Case 2: set the embedding function on the vector index under `EMBEDDING_KEY`.
    let mut embedding_key_variant = Schema::new_default(KnnIndex::Spann);
    if let Some(vector_index) = embedding_key_variant
        .keys
        .get_mut(EMBEDDING_KEY)
        .and_then(|value_types| value_types.float_list.as_mut())
        .and_then(|float_list| float_list.vector_index.as_mut())
    {
        vector_index.config.embedding_function = Some(EmbeddingFunctionConfiguration::Legacy);
    }
    assert!(!embedding_key_variant.is_default());
}

#[test]
fn test_add_merges_keys_by_value_type() {
let mut schema_a = Schema::new_default(KnnIndex::Hnsw);
Expand Down
Loading