From a021c46ef8adba9491f190284731e0ade6e23033 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan Date: Mon, 3 Nov 2025 15:55:44 -0800 Subject: [PATCH] [TST] add schema proptest --- chromadb/test/property/strategies.py | 390 ++++++++++++- chromadb/test/property/test_schema.py | 764 ++++++++++++++++++++++++++ 2 files changed, 1149 insertions(+), 5 deletions(-) create mode 100644 chromadb/test/property/test_schema.py diff --git a/chromadb/test/property/strategies.py b/chromadb/test/property/strategies.py index 5ed57ebf44c..6512090a5fa 100644 --- a/chromadb/test/property/strategies.py +++ b/chromadb/test/property/strategies.py @@ -1,7 +1,7 @@ import hashlib import hypothesis import hypothesis.strategies as st -from typing import Any, Optional, List, Dict, Union, cast +from typing import Any, Optional, List, Dict, Union, cast, Tuple from typing_extensions import TypedDict import uuid import numpy as np @@ -9,6 +9,10 @@ import chromadb.api.types as types import re from hypothesis.strategies._internal.strategies import SearchStrategy +from chromadb.test.api.test_schema_e2e import ( + SimpleEmbeddingFunction, + DeterministicSparseEmbeddingFunction, +) from chromadb.test.conftest import NOT_CLUSTER_ONLY from dataclasses import dataclass from chromadb.api.types import ( @@ -17,12 +21,27 @@ EmbeddingFunction, Embeddings, Metadata, + Schema, + CollectionMetadata, + VectorIndexConfig, + SparseVectorIndexConfig, + StringInvertedIndexConfig, + IntInvertedIndexConfig, + FloatInvertedIndexConfig, + BoolInvertedIndexConfig, + HnswIndexConfig, + SpannIndexConfig, + Space, ) from chromadb.types import LiteralValue, WhereOperator, LogicalOperator -from chromadb.test.conftest import is_spann_disabled_mode, skip_reason_spann_disabled +from chromadb.test.conftest import is_spann_disabled_mode from chromadb.api.collection_configuration import ( CreateCollectionConfiguration, CreateSpannConfiguration, + CreateHNSWConfiguration, +) +from chromadb.utils.embedding_functions import ( + 
register_embedding_function, ) # Set the random seed for reproducibility @@ -266,6 +285,365 @@ class ExternalCollection: embedding_function: Optional[types.EmbeddingFunction[Embeddable]] +@register_embedding_function +class SimpleIpEmbeddingFunction(SimpleEmbeddingFunction): + """Simple embedding function with ip space for persistence tests.""" + + def default_space(self) -> str: # type: ignore[override] + return "ip" + + +@st.composite +def vector_index_config_strategy(draw: st.DrawFn) -> VectorIndexConfig: + """Generate VectorIndexConfig with optional space, embedding_function, source_key, hnsw, spann.""" + space = None + embedding_function = None + source_key = None + hnsw = None + spann = None + + if draw(st.booleans()): + space = draw(st.sampled_from(["cosine", "l2", "ip"])) + + if draw(st.booleans()): + embedding_function = SimpleIpEmbeddingFunction( + dim=draw(st.integers(min_value=1, max_value=1000)) + ) + + if draw(st.booleans()): + source_key = draw(st.one_of(st.just("#document"), safe_text)) + + index_choice = draw(st.sampled_from(["hnsw", "spann", "none"])) + + if index_choice == "hnsw": + hnsw = HnswIndexConfig( + ef_construction=draw(st.integers(min_value=1, max_value=1000)) + if draw(st.booleans()) + else None, + max_neighbors=draw(st.integers(min_value=1, max_value=1000)) + if draw(st.booleans()) + else None, + ef_search=draw(st.integers(min_value=1, max_value=1000)) + if draw(st.booleans()) + else None, + sync_threshold=draw(st.integers(min_value=2, max_value=10000)) + if draw(st.booleans()) + else None, + resize_factor=draw(st.floats(min_value=1.0, max_value=5.0)) + if draw(st.booleans()) + else None, + ) + elif index_choice == "spann": + spann = SpannIndexConfig( + search_nprobe=draw(st.integers(min_value=1, max_value=128)) + if draw(st.booleans()) + else None, + write_nprobe=draw(st.integers(min_value=1, max_value=64)) + if draw(st.booleans()) + else None, + ef_construction=draw(st.integers(min_value=1, max_value=200)) + if draw(st.booleans()) + 
else None, + ef_search=draw(st.integers(min_value=1, max_value=200)) + if draw(st.booleans()) + else None, + max_neighbors=draw(st.integers(min_value=1, max_value=64)) + if draw(st.booleans()) + else None, + reassign_neighbor_count=draw(st.integers(min_value=1, max_value=64)) + if draw(st.booleans()) + else None, + split_threshold=draw(st.integers(min_value=50, max_value=200)) + if draw(st.booleans()) + else None, + merge_threshold=draw(st.integers(min_value=25, max_value=100)) + if draw(st.booleans()) + else None, + ) + + return VectorIndexConfig( + space=cast(Space, space), + embedding_function=embedding_function, + source_key=source_key, + hnsw=hnsw, + spann=spann, + ) + + +@st.composite +def sparse_vector_index_config_strategy(draw: st.DrawFn) -> SparseVectorIndexConfig: + """Generate SparseVectorIndexConfig with optional embedding_function, source_key, bm25.""" + embedding_function = None + source_key = None + bm25 = None + + if draw(st.booleans()): + embedding_function = DeterministicSparseEmbeddingFunction() + source_key = draw(st.one_of(st.just("#document"), safe_text)) + + if draw(st.booleans()): + bm25 = draw(st.booleans()) + + return SparseVectorIndexConfig( + embedding_function=embedding_function, + source_key=source_key, + bm25=bm25, + ) + + +@st.composite +def schema_strategy(draw: st.DrawFn) -> Optional[Schema]: + """Generate a Schema object with various create_index/delete_index operations.""" + if draw(st.booleans()): + return None + + schema = Schema() + + # Decide how many operations to perform + num_operations = draw(st.integers(min_value=0, max_value=5)) + sparse_index_created = False + + for _ in range(num_operations): + operation = draw(st.sampled_from(["create_index", "delete_index"])) + config_type = draw( + st.sampled_from( + [ + "string_inverted", + "int_inverted", + "float_inverted", + "bool_inverted", + "vector", + "sparse_vector", + ] + ) + ) + + # Decide if we're setting on a key or globally + use_key = draw(st.booleans()) + key = 
None + if use_key and config_type != "vector": + # Vector indexes can't be set on specific keys, only globally + key = draw(safe_text) + + if operation == "create_index": + if config_type == "string_inverted": + schema.create_index(config=StringInvertedIndexConfig(), key=key) + elif config_type == "int_inverted": + schema.create_index(config=IntInvertedIndexConfig(), key=key) + elif config_type == "float_inverted": + schema.create_index(config=FloatInvertedIndexConfig(), key=key) + elif config_type == "bool_inverted": + schema.create_index(config=BoolInvertedIndexConfig(), key=key) + elif config_type == "vector": + vector_config = draw(vector_index_config_strategy()) + schema.create_index(config=vector_config, key=None) + elif ( + config_type == "sparse_vector" + and not is_spann_disabled_mode + and not sparse_index_created + ): + sparse_config = draw(sparse_vector_index_config_strategy()) + # Sparse vector MUST have a key + if key is None: + key = draw(safe_text) + schema.create_index(config=sparse_config, key=key) + sparse_index_created = True + + elif operation == "delete_index": + if config_type == "string_inverted": + schema.delete_index(config=StringInvertedIndexConfig(), key=key) + elif config_type == "int_inverted": + schema.delete_index(config=IntInvertedIndexConfig(), key=key) + elif config_type == "float_inverted": + schema.delete_index(config=FloatInvertedIndexConfig(), key=key) + elif config_type == "bool_inverted": + schema.delete_index(config=BoolInvertedIndexConfig(), key=key) + # Vector, FTS, and sparse_vector deletion is not currently supported + + return schema + + +@st.composite +def metadata_with_hnsw_strategy(draw: st.DrawFn) -> Optional[CollectionMetadata]: + """Generate metadata with hnsw parameters.""" + metadata: CollectionMetadata = {} + + if draw(st.booleans()): + metadata["hnsw:space"] = draw(st.sampled_from(["cosine", "l2", "ip"])) + if draw(st.booleans()): + metadata["hnsw:construction_ef"] = draw( + st.integers(min_value=1, 
max_value=1000) + ) + if draw(st.booleans()): + metadata["hnsw:search_ef"] = draw(st.integers(min_value=1, max_value=1000)) + if draw(st.booleans()): + metadata["hnsw:M"] = draw(st.integers(min_value=1, max_value=1000)) + if draw(st.booleans()): + metadata["hnsw:resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0)) + if draw(st.booleans()): + metadata["hnsw:sync_threshold"] = draw( + st.integers(min_value=2, max_value=10000) + ) + + return metadata if metadata else None + + +@st.composite +def create_configuration_strategy( + draw: st.DrawFn, +) -> Optional[CreateCollectionConfiguration]: + """Generate CreateCollectionConfiguration with mutual exclusivity rules.""" + configuration: CreateCollectionConfiguration = {} + + # Optionally set embedding_function (independent) + if draw(st.booleans()): + configuration["embedding_function"] = SimpleIpEmbeddingFunction( + dim=draw(st.integers(min_value=1, max_value=1000)) + ) + + # Decide: set space only, OR set hnsw config, OR set spann config + config_choice = draw( + st.sampled_from( + ["space_only_hnsw", "space_only_spann", "hnsw", "spann", "none"] + ) + ) + + if config_choice == "space_only_hnsw": + configuration["hnsw"] = CreateHNSWConfiguration( + space=draw(st.sampled_from(["cosine", "l2", "ip"])) + ) + elif config_choice == "space_only_spann": + configuration["spann"] = CreateSpannConfiguration( + space=draw(st.sampled_from(["cosine", "l2", "ip"])) + ) + elif config_choice == "hnsw": + # Set hnsw config (optionally with space) + hnsw_config: CreateHNSWConfiguration = {} + if draw(st.booleans()): + hnsw_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"])) + hnsw_config["ef_construction"] = draw(st.integers(min_value=1, max_value=1000)) + hnsw_config["ef_search"] = draw(st.integers(min_value=1, max_value=1000)) + hnsw_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=1000)) + hnsw_config["sync_threshold"] = draw(st.integers(min_value=2, max_value=10000)) + 
hnsw_config["resize_factor"] = draw(st.floats(min_value=1.0, max_value=5.0)) + configuration["hnsw"] = hnsw_config + elif config_choice == "spann": + # Set spann config (optionally with space) + spann_config: CreateSpannConfiguration = {} + if draw(st.booleans()): + spann_config["space"] = draw(st.sampled_from(["cosine", "l2", "ip"])) + spann_config["search_nprobe"] = draw(st.integers(min_value=1, max_value=128)) + spann_config["write_nprobe"] = draw(st.integers(min_value=1, max_value=64)) + spann_config["ef_construction"] = draw(st.integers(min_value=1, max_value=200)) + spann_config["ef_search"] = draw(st.integers(min_value=1, max_value=200)) + spann_config["max_neighbors"] = draw(st.integers(min_value=1, max_value=64)) + spann_config["reassign_neighbor_count"] = draw( + st.integers(min_value=1, max_value=64) + ) + spann_config["split_threshold"] = draw(st.integers(min_value=50, max_value=200)) + spann_config["merge_threshold"] = draw(st.integers(min_value=25, max_value=100)) + configuration["spann"] = spann_config + + return configuration if configuration else None + + +@dataclass +class CollectionInputCombination: + """ + Input tuple for collection creation tests. 
+ """ + + metadata: Optional[CollectionMetadata] + configuration: Optional[CreateCollectionConfiguration] + schema: Optional[Schema] + schema_vector_info: Optional[Dict[str, Any]] + kind: str + + +def non_none_items(items: Dict[str, Any]) -> Dict[str, Any]: + return {k: v for k, v in items.items() if v is not None} + + +def vector_index_to_dict(config: VectorIndexConfig) -> Dict[str, Any]: + embedding_default_space: Optional[str] = None + if config.embedding_function is not None and hasattr( + config.embedding_function, "default_space" + ): + embedding_default_space = cast(str, config.embedding_function.default_space()) + + return { + "space": config.space, + "hnsw": config.hnsw.model_dump(exclude_none=True) if config.hnsw else None, + "spann": config.spann.model_dump(exclude_none=True) if config.spann else None, + "embedding_function_default_space": embedding_default_space, + } + + +@st.composite +def _schema_input_strategy( + draw: st.DrawFn, +) -> Tuple[Schema, Dict[str, Any]]: + schema = Schema() + vector_config = draw(vector_index_config_strategy()) + schema.create_index(config=vector_config, key=None) + return schema, vector_index_to_dict(vector_config) + + +@st.composite +def metadata_configuration_schema_strategy( + draw: st.DrawFn, +) -> CollectionInputCombination: + """ + Generate compatible combinations of metadata, configuration, and schema inputs. 
+ """ + + choice = draw( + st.sampled_from( + [ + "none", + "metadata", + "configuration", + "metadata_configuration", + "schema", + ] + ) + ) + + metadata: Optional[CollectionMetadata] = None + configuration: Optional[CreateCollectionConfiguration] = None + schema: Optional[Schema] = None + schema_info: Optional[Dict[str, Any]] = None + + if choice in ("metadata", "metadata_configuration"): + metadata = draw( + metadata_with_hnsw_strategy().filter( + lambda value: value is not None and len(value) > 0 + ) + ) + + if choice in ("configuration", "metadata_configuration"): + configuration = draw( + create_configuration_strategy().filter( + lambda value: value is not None + and ( + (value.get("hnsw") is not None and len(value["hnsw"]) > 0) + or (value.get("spann") is not None and len(value["spann"]) > 0) + ) + ) + ) + + if choice == "schema": + schema, schema_info = draw(_schema_input_strategy()) + + return CollectionInputCombination( + metadata=metadata, + configuration=configuration, + schema=schema, + schema_vector_info=schema_info, + kind=choice, + ) + + @dataclass class Collection(ExternalCollection): """ @@ -344,7 +722,7 @@ def collections( spann_config: CreateSpannConfiguration = { "space": spann_space, "write_nprobe": 4, - "reassign_neighbor_count": 4 + "reassign_neighbor_count": 4, } collection_config = { "spann": spann_config, @@ -395,7 +773,7 @@ def collections( known_document_keywords=known_document_keywords, has_embeddings=has_embeddings, embedding_function=embedding_function, - collection_config=collection_config + collection_config=collection_config, ) @@ -421,7 +799,9 @@ def metadata( del metadata[key] # type: ignore # Finally, add in some of the known keys for the collection sampling_dict: Dict[str, st.SearchStrategy[Union[str, int, float]]] = { - k: st.just(v) for k, v in collection.known_metadata_keys.items() + k: st.just(v) + for k, v in collection.known_metadata_keys.items() + if isinstance(v, (str, int, float)) } 
metadata.update(draw(st.fixed_dictionaries({}, optional=sampling_dict))) # type: ignore # We don't allow submitting empty metadata diff --git a/chromadb/test/property/test_schema.py b/chromadb/test/property/test_schema.py new file mode 100644 index 00000000000..1ae3ca457d3 --- /dev/null +++ b/chromadb/test/property/test_schema.py @@ -0,0 +1,764 @@ +import math +from typing import Any, Dict, Optional, Set, Tuple, cast + +from hypothesis import given + +from chromadb.api import ClientAPI +from chromadb.api.collection_configuration import CreateCollectionConfiguration +from chromadb.api.types import ( + CollectionMetadata, + EMBEDDING_KEY, + Schema, +) +from chromadb.test.property import strategies +from chromadb.test.property.invariants import check_metadata +from chromadb.test.conftest import ( + reset, + is_spann_disabled_mode, +) + + +HNSW_METADATA_TO_CONFIG: Dict[str, str] = { + "hnsw:space": "space", + "hnsw:construction_ef": "ef_construction", + "hnsw:search_ef": "ef_search", + "hnsw:M": "max_neighbors", + "hnsw:sync_threshold": "sync_threshold", + "hnsw:resize_factor": "resize_factor", +} + +HNSW_FIELDS = [ + "space", + "ef_construction", + "ef_search", + "max_neighbors", + "sync_threshold", + "resize_factor", +] + +HNSW_DEFAULTS: Dict[str, Any] = { + "space": "l2", + "ef_construction": 100, + "ef_search": 100, + "max_neighbors": 16, + "sync_threshold": 1000, + "resize_factor": 1.2, +} + +SPANN_FIELDS = [ + "space", + "search_nprobe", + "write_nprobe", + "ef_construction", + "ef_search", + "max_neighbors", + "reassign_neighbor_count", + "split_threshold", + "merge_threshold", +] + +SPANN_DEFAULTS: Dict[str, Any] = { + "space": "l2", + "search_nprobe": 64, + "write_nprobe": 32, + "ef_construction": 200, + "ef_search": 200, + "max_neighbors": 64, + "reassign_neighbor_count": 64, + "split_threshold": 50, + "merge_threshold": 25, +} + + +def _extract_vector_configs_from_schema( + schema: Schema, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + defaults_float = 
schema.defaults.float_list + assert defaults_float is not None + defaults_vi = defaults_float.vector_index + assert defaults_vi is not None + + embedding_float = schema.keys[EMBEDDING_KEY].float_list + assert embedding_float is not None + embedding_vi = embedding_float.vector_index + assert embedding_vi is not None + + return ( + strategies.vector_index_to_dict(defaults_vi.config), + strategies.vector_index_to_dict(embedding_vi.config), + ) + + +def _compute_expected_config( + spann_active: bool, + metadata: Optional[CollectionMetadata], + configuration: Optional[CreateCollectionConfiguration], + schema_vector_index_config: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + """ + some assumptions/assertions: + 1. we are not testing failure paths. any config built is/should be valid. invalid cases can be tested separately or in e2e tests + ex: if configuration is set, schema is not set. if schema is set, configuration is not set. both hnsw and spann cannot be set at the same time in config or schema + """ + if spann_active: + # start off creating default spann config, we slowly modify it to match whatever prop test provides + expected = SPANN_DEFAULTS.copy() + space_set = False + # there are some edge cases where space is set in hnsw config and in metadata + # in this case, we check if the space set by config is not the default, and if so, we don't try to use the one from metadata + # essentially if either metadata or hnsw config provides a non-default space, we use that one, with config hnsw taking priority over metadata + should_try_metadata = True + + if configuration: + spann_cfg = configuration.get("spann") + if spann_cfg: + spann_cfg_dict = cast(Dict[str, Any], spann_cfg) + # update expected with whatever prop test provides + expected.update(strategies.non_none_items(spann_cfg_dict)) + # if space is set in spann, this now takes priority over all else + if spann_cfg_dict.get("space") is not None: + expected["space"] = spann_cfg_dict["space"] + space_set = True + 
should_try_metadata = False + hnsw_cfg = configuration.get("hnsw") + if hnsw_cfg: + hnsw_cfg_dict = cast(Dict[str, Any], hnsw_cfg) + hnsw_non_none = strategies.non_none_items(hnsw_cfg_dict) + for key, value in hnsw_non_none.items(): + if value is not None and value != HNSW_DEFAULTS[key]: + # if any hnsw config is not the default, we do not use metadata at all + # here's a sample case where this is needed: hnsw doesn't set space (so l2 by default), but sets ef_construction, metadata sets space to ip + # in this case, they were aware of hnsw config, and chose not to set space in it. therefore the config takes priority over metadata + should_try_metadata = False + # when SPANN is active and HNSW config is provided, use space from hnsw config + if hnsw_cfg_dict.get("space") is not None and not space_set: + # if the space set by config is not the default, don't try to use the one from metadata + if hnsw_cfg_dict.get("space") != HNSW_DEFAULTS["space"]: + should_try_metadata = False + expected["space"] = hnsw_cfg_dict["space"] + space_set = True + + if schema_vector_index_config: + if schema_vector_index_config.get("space") is not None: + expected["space"] = schema_vector_index_config["space"] + space_set = True + if schema_vector_index_config.get("spann"): + spann_schema = strategies.non_none_items( + schema_vector_index_config["spann"] + ) + expected.update(spann_schema) + + if ( + metadata + and metadata.get("hnsw:space") is not None + and metadata.get("hnsw:space") != SPANN_DEFAULTS["space"] + and should_try_metadata + ): + expected["space"] = metadata["hnsw:space"] + space_set = True + + if ( + schema_vector_index_config + and schema_vector_index_config.get("embedding_function_default_space") + is not None + and schema_vector_index_config.get("embedding_function_default_space") + != SPANN_DEFAULTS["space"] + and not space_set + ): + expected["space"] = schema_vector_index_config[ + "embedding_function_default_space" + ] + space_set = True + + if ( + 
not space_set + and configuration + and configuration.get("embedding_function") is not None + ): + ef = configuration["embedding_function"] + if hasattr(ef, "default_space"): + expected["space"] = cast(Any, ef).default_space() + + return expected + + expected = HNSW_DEFAULTS.copy() + space_set = False + configured_hnsw_keys: Set[str] = set() + should_try_metadata = True + + if configuration: + hnsw_cfg_raw = configuration.get("hnsw") + if hnsw_cfg_raw is not None: + hnsw_dict: Dict[str, Any] = cast(Dict[str, Any], hnsw_cfg_raw) + hnsw_non_none = strategies.non_none_items(hnsw_dict) + expected.update(hnsw_non_none) + for key, value in hnsw_non_none.items(): + # if any hnsw config is not the default, we do not use metadata at all + if value is not None and value != HNSW_DEFAULTS[key]: + should_try_metadata = False + configured_hnsw_keys.update(hnsw_non_none.keys()) + if hnsw_non_none.get("space") is not None and not space_set: + if hnsw_non_none.get("space") != HNSW_DEFAULTS["space"]: + should_try_metadata = False + space_set = True + spann_cfg_raw = configuration.get("spann") + if spann_cfg_raw is not None: + spann_dict: Dict[str, Any] = cast(Dict[str, Any], spann_cfg_raw) + if spann_dict.get("space") is not None and not space_set: + expected["space"] = spann_dict["space"] + space_set = True + should_try_metadata = False + + if should_try_metadata and metadata: + for key, cfg_key in HNSW_METADATA_TO_CONFIG.items(): + if metadata.get(key) is None: + continue + if cfg_key == "space": + expected[cfg_key] = metadata[key] + space_set = True + configured_hnsw_keys.add(cfg_key) + continue + if cfg_key not in configured_hnsw_keys: + expected[cfg_key] = metadata[key] + configured_hnsw_keys.add(cfg_key) + + if schema_vector_index_config: + if schema_vector_index_config.get("space") is not None: + expected["space"] = schema_vector_index_config["space"] + space_set = True + if schema_vector_index_config.get("hnsw"): + expected.update( + 
strategies.non_none_items(schema_vector_index_config["hnsw"]) + ) + elif schema_vector_index_config.get("spann"): + # Schema provided SPANN configuration while HNSW is active; ignore. + pass + + if ( + schema_vector_index_config + and schema_vector_index_config.get("embedding_function_default_space") + is not None + and not space_set + ): + expected["space"] = schema_vector_index_config[ + "embedding_function_default_space" + ] + space_set = True + + if ( + not space_set + and configuration + and configuration.get("embedding_function") is not None + ): + ef = configuration["embedding_function"] + if hasattr(ef, "default_space"): + expected["space"] = cast(Any, ef).default_space() + + return expected + + +def _assert_config_values( + actual: Dict[str, Any], + expected: Dict[str, Any], + spann_active: bool, +) -> None: + fields = SPANN_FIELDS if spann_active else HNSW_FIELDS + for field in fields: + actual_value = actual.get(field) + expected_value = expected[field] + # Use approximate equality for floating-point values + if isinstance(actual_value, float) and isinstance(expected_value, float): + assert math.isclose( + actual_value, expected_value, rel_tol=1e-9, abs_tol=1e-9 + ), f"{field} mismatch: expected {expected_value}, got {actual_value}" + else: + assert ( + actual_value == expected_value + ), f"{field} mismatch: expected {expected_value}, got {actual_value}" + + +def _assert_schema_values( + vector_info: Dict[str, Any], + expected: Dict[str, Any], + spann_active: bool, +) -> None: + assert vector_info["space"] == expected["space"] + if spann_active: + spann_cfg = cast(Optional[Dict[str, Any]], vector_info["spann"]) + assert spann_cfg is not None + for field in SPANN_FIELDS: + if field == "space": + continue + actual_value = spann_cfg.get(field) + expected_value = expected[field] + # Use approximate equality for floating-point values + if isinstance(actual_value, float) and isinstance(expected_value, float): + assert math.isclose( + actual_value, 
expected_value, rel_tol=1e-9, abs_tol=1e-9 + ), f"{field} mismatch: expected {expected_value}, got {actual_value}" + else: + assert ( + actual_value == expected_value + ), f"{field} mismatch: expected {expected_value}, got {actual_value}" + else: + hnsw_cfg = cast(Optional[Dict[str, Any]], vector_info["hnsw"]) + assert hnsw_cfg is not None + for field in HNSW_FIELDS: + if field == "space": + continue + actual_value = hnsw_cfg.get(field) + expected_value = expected[field] + # Use approximate equality for floating-point values + if isinstance(actual_value, float) and isinstance(expected_value, float): + assert math.isclose( + actual_value, expected_value, rel_tol=1e-9, abs_tol=1e-9 + ), f"{field} mismatch: expected {expected_value}, got {actual_value}" + else: + assert ( + actual_value == expected_value + ), f"{field} mismatch: expected {expected_value}, got {actual_value}" + + +def _get_default_schema_indexes() -> Dict[str, Dict[str, Any]]: + """ + Get expected index states for default schema (when schema=None). + Based on Schema._initialize_defaults() and _initialize_keys(). + """ + return { + "defaults": { + "string_inverted": {"enabled": True}, + "int_inverted": {"enabled": True}, + "float_inverted": {"enabled": True}, + "bool_inverted": {"enabled": True}, + "sparse_vector": {"enabled": False}, + "fts_index": {"enabled": False}, + "vector_index": {"enabled": False}, + }, + "#document": { + "string_inverted": {"enabled": False}, + "fts_index": {"enabled": True}, + }, + "#embedding": { + "vector_index": {"enabled": True}, + }, + } + + +def _extract_expected_schema_indexes( + schema: Schema, +) -> Dict[str, Dict[str, Any]]: + """ + Extract expected index states from input schema. + Returns a dict mapping key -> index_type -> enabled/config info. 
+ """ + expected: Dict[str, Dict[str, Any]] = {} + + # Check defaults + if schema.defaults.string and schema.defaults.string.string_inverted_index: + if "defaults" not in expected: + expected["defaults"] = {} + expected["defaults"]["string_inverted"] = { + "enabled": schema.defaults.string.string_inverted_index.enabled, + } + + if schema.defaults.int_value and schema.defaults.int_value.int_inverted_index: + if "defaults" not in expected: + expected["defaults"] = {} + expected["defaults"]["int_inverted"] = { + "enabled": schema.defaults.int_value.int_inverted_index.enabled, + } + + if schema.defaults.float_value and schema.defaults.float_value.float_inverted_index: + if "defaults" not in expected: + expected["defaults"] = {} + expected["defaults"]["float_inverted"] = { + "enabled": schema.defaults.float_value.float_inverted_index.enabled, + } + + if schema.defaults.boolean and schema.defaults.boolean.bool_inverted_index: + if "defaults" not in expected: + expected["defaults"] = {} + expected["defaults"]["bool_inverted"] = { + "enabled": schema.defaults.boolean.bool_inverted_index.enabled, + } + + if ( + schema.defaults.sparse_vector + and schema.defaults.sparse_vector.sparse_vector_index + ): + if "defaults" not in expected: + expected["defaults"] = {} + expected["defaults"]["sparse_vector"] = { + "enabled": schema.defaults.sparse_vector.sparse_vector_index.enabled, + "config": schema.defaults.sparse_vector.sparse_vector_index.config, + } + + # Check per-key indexes + for key, value_types in schema.keys.items(): + if key in (EMBEDDING_KEY, "#document"): + # Skip special keys - they're handled by vector index test + continue + + key_expected: Dict[str, Any] = {} + + if value_types.string and value_types.string.string_inverted_index: + key_expected["string_inverted"] = { + "enabled": value_types.string.string_inverted_index.enabled, + } + + if value_types.int_value and value_types.int_value.int_inverted_index: + key_expected["int_inverted"] = { + "enabled": 
value_types.int_value.int_inverted_index.enabled, + } + + if value_types.float_value and value_types.float_value.float_inverted_index: + key_expected["float_inverted"] = { + "enabled": value_types.float_value.float_inverted_index.enabled, + } + + if value_types.boolean and value_types.boolean.bool_inverted_index: + key_expected["bool_inverted"] = { + "enabled": value_types.boolean.bool_inverted_index.enabled, + } + + if value_types.sparse_vector and value_types.sparse_vector.sparse_vector_index: + key_expected["sparse_vector"] = { + "enabled": value_types.sparse_vector.sparse_vector_index.enabled, + "config": value_types.sparse_vector.sparse_vector_index.config, + } + + if key_expected: + expected[key] = key_expected + + return expected + + +def _assert_schema_indexes( + actual_schema: Schema, + expected_indexes: Dict[str, Dict[str, Any]], +) -> None: + """Assert that the actual schema matches expected index states.""" + + # Check defaults + if "defaults" in expected_indexes: + defaults_expected = expected_indexes["defaults"] + defaults_actual = actual_schema.defaults + + if "string_inverted" in defaults_expected: + expected_enabled = defaults_expected["string_inverted"]["enabled"] + actual_string = defaults_actual.string + if actual_string and actual_string.string_inverted_index: + assert ( + actual_string.string_inverted_index.enabled == expected_enabled + ), f"defaults string_inverted enabled mismatch: expected {expected_enabled}, got {actual_string.string_inverted_index.enabled}" + else: + # If not explicitly set, defaults should be enabled + assert expected_enabled, "defaults string_inverted should be enabled" + + if "int_inverted" in defaults_expected: + expected_enabled = defaults_expected["int_inverted"]["enabled"] + actual_int = defaults_actual.int_value + if actual_int and actual_int.int_inverted_index: + assert ( + actual_int.int_inverted_index.enabled == expected_enabled + ), f"defaults int_inverted enabled mismatch: expected {expected_enabled}, got 
{actual_int.int_inverted_index.enabled}" + else: + assert expected_enabled, "defaults int_inverted should be enabled" + + if "float_inverted" in defaults_expected: + expected_enabled = defaults_expected["float_inverted"]["enabled"] + actual_float = defaults_actual.float_value + if actual_float and actual_float.float_inverted_index: + assert ( + actual_float.float_inverted_index.enabled == expected_enabled + ), f"defaults float_inverted enabled mismatch: expected {expected_enabled}, got {actual_float.float_inverted_index.enabled}" + else: + assert expected_enabled, "defaults float_inverted should be enabled" + + if "bool_inverted" in defaults_expected: + expected_enabled = defaults_expected["bool_inverted"]["enabled"] + actual_bool = defaults_actual.boolean + if actual_bool and actual_bool.bool_inverted_index: + assert ( + actual_bool.bool_inverted_index.enabled == expected_enabled + ), f"defaults bool_inverted enabled mismatch: expected {expected_enabled}, got {actual_bool.bool_inverted_index.enabled}" + else: + assert expected_enabled, "defaults bool_inverted should be enabled" + + if "sparse_vector" in defaults_expected: + expected_enabled = defaults_expected["sparse_vector"]["enabled"] + actual_sparse = defaults_actual.sparse_vector + assert actual_sparse is not None, "defaults sparse_vector should exist" + assert ( + actual_sparse.sparse_vector_index is not None + ), "defaults sparse_vector_index should exist" + assert ( + actual_sparse.sparse_vector_index.enabled == expected_enabled + ), f"defaults sparse_vector enabled mismatch: expected {expected_enabled}, got {actual_sparse.sparse_vector_index.enabled}" + # Validate config fields if config is provided in expected + if "config" in defaults_expected["sparse_vector"]: + expected_config = defaults_expected["sparse_vector"]["config"] + actual_config = actual_sparse.sparse_vector_index.config + if expected_config.bm25 is not None: + assert ( + actual_config.bm25 == expected_config.bm25 + ), f"defaults 
sparse_vector bm25 mismatch: expected {expected_config.bm25}, got {actual_config.bm25}" + if expected_config.source_key is not None: + assert ( + actual_config.source_key == expected_config.source_key + ), f"defaults sparse_vector source_key mismatch: expected {expected_config.source_key}, got {actual_config.source_key}" + + if "fts_index" in defaults_expected: + expected_enabled = defaults_expected["fts_index"]["enabled"] + actual_string = defaults_actual.string + assert actual_string is not None, "defaults string should exist" + assert ( + actual_string.fts_index is not None + ), "defaults fts_index should exist" + assert ( + actual_string.fts_index.enabled == expected_enabled + ), f"defaults fts_index enabled mismatch: expected {expected_enabled}, got {actual_string.fts_index.enabled}" + + if "vector_index" in defaults_expected: + expected_enabled = defaults_expected["vector_index"]["enabled"] + actual_float_list = defaults_actual.float_list + assert actual_float_list is not None, "defaults float_list should exist" + assert ( + actual_float_list.vector_index is not None + ), "defaults vector_index should exist" + assert ( + actual_float_list.vector_index.enabled == expected_enabled + ), f"defaults vector_index enabled mismatch: expected {expected_enabled}, got {actual_float_list.vector_index.enabled}" + + # Check per-key indexes + for key, key_expected in expected_indexes.items(): + if key == "defaults": + continue + + assert key in actual_schema.keys, f"Expected key '{key}' not found in schema" + actual_value_types = actual_schema.keys[key] + + if "string_inverted" in key_expected: + expected_enabled = key_expected["string_inverted"]["enabled"] + actual_string = actual_value_types.string + assert actual_string is not None, f"Key '{key}' string should exist" + assert ( + actual_string.string_inverted_index is not None + ), f"Key '{key}' string_inverted_index should exist" + assert ( + actual_string.string_inverted_index.enabled == expected_enabled + ), f"Key 
'{key}' string_inverted enabled mismatch: expected {expected_enabled}, got {actual_string.string_inverted_index.enabled}" + + if "int_inverted" in key_expected: + expected_enabled = key_expected["int_inverted"]["enabled"] + actual_int = actual_value_types.int_value + assert actual_int is not None, f"Key '{key}' int_value should exist" + assert ( + actual_int.int_inverted_index is not None + ), f"Key '{key}' int_inverted_index should exist" + assert ( + actual_int.int_inverted_index.enabled == expected_enabled + ), f"Key '{key}' int_inverted enabled mismatch: expected {expected_enabled}, got {actual_int.int_inverted_index.enabled}" + + if "float_inverted" in key_expected: + expected_enabled = key_expected["float_inverted"]["enabled"] + actual_float = actual_value_types.float_value + assert actual_float is not None, f"Key '{key}' float_value should exist" + assert ( + actual_float.float_inverted_index is not None + ), f"Key '{key}' float_inverted_index should exist" + assert ( + actual_float.float_inverted_index.enabled == expected_enabled + ), f"Key '{key}' float_inverted enabled mismatch: expected {expected_enabled}, got {actual_float.float_inverted_index.enabled}" + + if "bool_inverted" in key_expected: + expected_enabled = key_expected["bool_inverted"]["enabled"] + actual_bool = actual_value_types.boolean + assert actual_bool is not None, f"Key '{key}' boolean should exist" + assert ( + actual_bool.bool_inverted_index is not None + ), f"Key '{key}' bool_inverted_index should exist" + assert ( + actual_bool.bool_inverted_index.enabled == expected_enabled + ), f"Key '{key}' bool_inverted enabled mismatch: expected {expected_enabled}, got {actual_bool.bool_inverted_index.enabled}" + + if "sparse_vector" in key_expected: + expected_enabled = key_expected["sparse_vector"]["enabled"] + expected_config = key_expected["sparse_vector"]["config"] + actual_sparse = actual_value_types.sparse_vector + assert actual_sparse is not None, f"Key '{key}' sparse_vector should exist" 
+ assert ( + actual_sparse.sparse_vector_index is not None + ), f"Key '{key}' sparse_vector_index should exist" + assert ( + actual_sparse.sparse_vector_index.enabled == expected_enabled + ), f"Key '{key}' sparse_vector enabled mismatch: expected {expected_enabled}, got {actual_sparse.sparse_vector_index.enabled}" + # Validate config fields match + actual_config = actual_sparse.sparse_vector_index.config + if expected_config.bm25 is not None: + assert ( + actual_config.bm25 == expected_config.bm25 + ), f"Key '{key}' sparse_vector bm25 mismatch: expected {expected_config.bm25}, got {actual_config.bm25}" + if expected_config.source_key is not None: + assert ( + actual_config.source_key == expected_config.source_key + ), f"Key '{key}' sparse_vector source_key mismatch: expected {expected_config.source_key}, got {actual_config.source_key}" + + if "fts_index" in key_expected: + expected_enabled = key_expected["fts_index"]["enabled"] + actual_string = actual_value_types.string + assert actual_string is not None, f"Key '{key}' string should exist" + assert ( + actual_string.fts_index is not None + ), f"Key '{key}' fts_index should exist" + assert ( + actual_string.fts_index.enabled == expected_enabled + ), f"Key '{key}' fts_index enabled mismatch: expected {expected_enabled}, got {actual_string.fts_index.enabled}" + + if "vector_index" in key_expected: + expected_enabled = key_expected["vector_index"]["enabled"] + actual_float_list = actual_value_types.float_list + assert actual_float_list is not None, f"Key '{key}' float_list should exist" + assert ( + actual_float_list.vector_index is not None + ), f"Key '{key}' vector_index should exist" + assert ( + actual_float_list.vector_index.enabled == expected_enabled + ), f"Key '{key}' vector_index enabled mismatch: expected {expected_enabled}, got {actual_float_list.vector_index.enabled}" + + +@given( + name=strategies.collection_name(), + optional_fields=strategies.metadata_configuration_schema_strategy(), +) +def 
test_vector_index_configuration_create_collection( + client: ClientAPI, + name: str, + optional_fields: strategies.CollectionInputCombination, +) -> None: + metadata = optional_fields.metadata + configuration = optional_fields.configuration + schema = optional_fields.schema + + reset(client) + collection = client.create_collection( + name=name, + metadata=metadata, + configuration=configuration, + schema=schema, + ) + + if metadata is None: + assert collection.metadata in (None, {}) + else: + check_metadata(metadata, collection.metadata) + + coll_config = collection.configuration + spann_active = not is_spann_disabled_mode + active_key = "spann" if spann_active else "hnsw" + inactive_key = "hnsw" if spann_active else "spann" + + active_block = coll_config.get(active_key) + inactive_block = coll_config.get(inactive_key) + + assert active_block is not None, f"{active_key} configuration missing" + assert inactive_block in ( + None, + {}, + ), f"{inactive_key} configuration should be absent" + + expected = _compute_expected_config( + spann_active=spann_active, + metadata=metadata, + configuration=configuration, + schema_vector_index_config=optional_fields.schema_vector_info, + ) + + _assert_config_values(cast(Dict[str, Any], active_block), expected, spann_active) + + # Check embedding function name if one was provided + if configuration and configuration.get("embedding_function") is not None: + ef = configuration["embedding_function"] + if ef is not None: + coll_ef = coll_config.get("embedding_function") + if coll_ef is not None: + ef_config = coll_ef.get_config() + if ef_config and ef_config.get("type") == "known": + assert hasattr( + ef, "name" + ), "embedding function should have name method" + assert ef_config.get("name") == ef.name(), ( + f"embedding function name mismatch: " + f"expected {ef.name()}, got {ef_config.get('name')}" + ) + + schema_result = collection.schema + assert schema_result is not None + defaults_cfg, embedding_cfg = 
_extract_vector_configs_from_schema(schema_result) + + if spann_active: + assert defaults_cfg["hnsw"] is None + assert embedding_cfg["hnsw"] is None + assert defaults_cfg["spann"] is not None + assert embedding_cfg["spann"] is not None + else: + assert defaults_cfg["spann"] is None + assert embedding_cfg["spann"] is None + assert defaults_cfg["hnsw"] is not None + assert embedding_cfg["hnsw"] is not None + + _assert_schema_values(defaults_cfg, expected, spann_active) + _assert_schema_values(embedding_cfg, expected, spann_active) + + # Check embedding function name in schema if one was provided + if configuration and configuration.get("embedding_function") is not None: + ef = configuration["embedding_function"] + if ef is not None: + # Check defaults vector index + defaults_ef = schema_result.defaults.float_list.vector_index.config.embedding_function # type: ignore[union-attr] + if defaults_ef is not None and hasattr(defaults_ef, "name"): + assert defaults_ef.name() == ef.name(), ( + f"defaults embedding function name mismatch: " + f"expected {ef.name()}, got {defaults_ef.name()}" + ) + # Check embedding key vector index + embedding_ef = schema_result.keys[EMBEDDING_KEY].float_list.vector_index.config.embedding_function # type: ignore[union-attr] + if embedding_ef is not None and hasattr(embedding_ef, "name"): + assert embedding_ef.name() == ef.name(), ( + f"embedding key embedding function name mismatch: " + f"expected {ef.name()}, got {embedding_ef.name()}" + ) + + +@given( + name=strategies.collection_name(), + schema=strategies.schema_strategy(), +) +def test_schema_create_and_get_collection( + client: ClientAPI, + name: str, + schema: Optional[Schema], +) -> None: + """ + Test that schema-only components (inverted indexes, sparse vector indexes) + are correctly created and persisted when creating a collection. 
+ """ + reset(client) + + if schema is None: + expected_indexes = _get_default_schema_indexes() + else: + expected_indexes = _extract_expected_schema_indexes(schema) + + collection = client.create_collection(name=name, schema=schema) + + # Get the returned schema + schema_result = collection.schema + assert schema_result is not None, "Schema should not be None" + + _assert_schema_indexes(schema_result, expected_indexes) + + collection = client.get_collection(name) + schema_result = collection.schema + assert schema_result is not None, "Schema should not be None" + _assert_schema_indexes(schema_result, expected_indexes)