4 changes: 4 additions & 0 deletions llama_stack/apis/vector_stores/vector_stores.py
@@ -18,13 +18,15 @@ class VectorStore(Resource):
:param type: Type of resource, always 'vector_store' for vector stores
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param distance_metric: Distance metric for vector similarity calculations (e.g., 'COSINE', 'L2', 'INNER_PRODUCT')
"""

type: Literal[ResourceType.vector_store] = ResourceType.vector_store

embedding_model: str
embedding_dimension: int
vector_store_name: str | None = None
distance_metric: str | None = None

@property
def vector_store_id(self) -> str:
@@ -42,10 +44,12 @@ class VectorStoreInput(BaseModel):
:param embedding_model: Name of the embedding model to use for vector generation
:param embedding_dimension: Dimension of the embedding vectors
:param provider_vector_store_id: (Optional) Provider-specific identifier for the vector store
:param distance_metric: (Optional) Distance metric for vector similarity calculations
"""

vector_store_id: str
embedding_model: str
embedding_dimension: int
provider_id: str | None = None
provider_vector_store_id: str | None = None
distance_metric: str | None = None
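Note: the new distance_metric field is optional on both VectorStore and VectorStoreInput, so existing configurations keep working unchanged. Below is a minimal sketch of declaring a store input with the field set; the import path is inferred from the file location above, and the identifiers, model name, and dimension are illustrative, not prescribed by this PR.

```python
# Sketch only: import path inferred from llama_stack/apis/vector_stores/vector_stores.py;
# identifiers, model name, and dimension are illustrative.
from llama_stack.apis.vector_stores import VectorStoreInput

store_input = VectorStoreInput(
    vector_store_id="my-docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    distance_metric="COSINE",  # optional; e.g. 'COSINE', 'L2', 'INNER_PRODUCT'
)
```

Leaving distance_metric unset preserves the previous behavior: each provider falls back to its own default (e.g. "L2" for FAISS, "COSINE" for the other providers changed below).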
4 changes: 4 additions & 0 deletions llama_stack/core/routers/vector_io.py
@@ -105,6 +105,7 @@ async def openai_create_vector_store(
embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension")
provider_id = extra.get("provider_id")
distance_metric = extra.get("distance_metric")

# Use default embedding model if not specified
if (
@@ -154,6 +155,7 @@ async def openai_create_vector_store(
provider_id=provider_id,
provider_vector_store_id=vector_store_id,
vector_store_name=params.name,
distance_metric=distance_metric,
)
provider = await self.routing_table.get_provider_impl(registered_vector_store.identifier)

@@ -162,6 +164,8 @@
params.model_extra = {}
params.model_extra["provider_vector_store_id"] = registered_vector_store.provider_resource_id
params.model_extra["provider_id"] = registered_vector_store.provider_id
if distance_metric is not None:
params.model_extra["distance_metric"] = distance_metric
if embedding_model is not None:
params.model_extra["embedding_model"] = embedding_model
if embedding_dimension is not None:
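Because the router pulls distance_metric out of the request's extra fields, callers on the OpenAI-compatible surface can pass it without any client-side changes. A hedged client sketch follows; the base URL, API key, and values are illustrative, and depending on the SDK version the vector-stores resource may live under client.beta instead.

```python
# Sketch only: assumes a running Llama Stack server exposing the
# OpenAI-compatible /v1 surface; base_url, api_key, and values are illustrative.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Fields the OpenAI schema does not define (embedding_model, embedding_dimension,
# distance_metric) ride along in the request body and surface server-side as
# params.model_extra, which the router reads above.
vector_store = client.vector_stores.create(
    name="my-docs",
    extra_body={
        "embedding_model": "all-MiniLM-L6-v2",
        "embedding_dimension": 384,
        "distance_metric": "L2",
    },
)
print(vector_store.id)
```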
2 changes: 2 additions & 0 deletions llama_stack/core/routing_tables/vector_stores.py
@@ -49,6 +49,7 @@ async def register_vector_store(
provider_id: str | None = None,
provider_vector_store_id: str | None = None,
vector_store_name: str | None = None,
distance_metric: str | None = None,
) -> Any:
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
@@ -73,6 +74,7 @@
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
vector_store_name=vector_store_name,
distance_metric=distance_metric,
)
await self.register_object(vector_store)
return vector_store
33 changes: 29 additions & 4 deletions llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -39,7 +39,11 @@


class FaissIndex(EmbeddingIndex):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
def __init__(
self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None, distance_metric: str = "L2"
):
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric
self.index = faiss.IndexFlatL2(dimension)
self.chunk_by_index: dict[int, Chunk] = {}
self.kvstore = kvstore
@@ -51,8 +55,10 @@ def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str
self.chunk_ids: list[Any] = []

@classmethod
async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
instance = cls(dimension, kvstore, bank_id)
async def create(
cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None, distance_metric: str = "L2"
):
instance = cls(dimension, kvstore, bank_id, distance_metric)
await instance.initialize()
return instance

@@ -175,6 +181,22 @@ async def query_hybrid(
"Hybrid search is not supported - underlying DB FAISS does not support this search mode"
)

def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by FAISS.

Args:
distance_metric: The distance metric to check

Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "L2":
# TODO: Implement support for other distance metrics in FAISS
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the FAISS provider. "
f"Currently only 'L2' is supported. Please use 'L2' or switch to a different provider."
)


class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
@@ -229,9 +251,12 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
await self.kvstore.set(key=key, value=vector_store.model_dump_json())

# Store in cache
distance_metric = vector_store.distance_metric or "L2"
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store=vector_store,
index=await FaissIndex.create(vector_store.embedding_dimension, self.kvstore, vector_store.identifier),
index=await FaissIndex.create(
vector_store.embedding_dimension, self.kvstore, vector_store.identifier, distance_metric
),
inference_api=self.inference_api,
)

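Since the check runs in FaissIndex.__init__, an unsupported metric is rejected at registration time, before any FAISS structures are built or persisted. A small sketch of that fail-fast behavior (assumes faiss is installed; the dimension is illustrative):

```python
# Sketch only: demonstrates the fail-fast metric check added above.
import asyncio

from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex


async def main() -> None:
    # The default "L2" metric builds an IndexFlatL2; any other value raises
    # NotImplementedError before an index is created.
    try:
        await FaissIndex.create(dimension=384, distance_metric="COSINE")
    except NotImplementedError as exc:
        print(exc)


asyncio.run(main())
```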
34 changes: 30 additions & 4 deletions llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -75,18 +75,27 @@ class SQLiteVecIndex(EmbeddingIndex):
- An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
"""

def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
def __init__(
self,
dimension: int,
db_path: str,
bank_id: str,
kvstore: KVStore | None = None,
distance_metric: str = "COSINE",
):
self.dimension = dimension
self.db_path = db_path
self.bank_id = bank_id
self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}")
self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}")
self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}")
self.kvstore = kvstore
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric

@classmethod
async def create(cls, dimension: int, db_path: str, bank_id: str):
instance = cls(dimension, db_path, bank_id)
async def create(cls, dimension: int, db_path: str, bank_id: str, distance_metric: str = "COSINE"):
instance = cls(dimension, db_path, bank_id, distance_metric=distance_metric)
await instance.initialize()
return instance

@@ -373,6 +382,22 @@ def _delete_chunks():

await asyncio.to_thread(_delete_chunks)

def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by SQLite-vec.

Args:
distance_metric: The distance metric to check

Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "COSINE":
# TODO: Implement support for other distance metrics in SQLite-vec
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the SQLite-vec provider. "
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
)


class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
"""
@@ -412,8 +437,9 @@ async def list_vector_stores(self) -> list[VectorStore]:
return [v.vector_store for v in self.cache.values()]

async def register_vector_store(self, vector_store: VectorStore) -> None:
distance_metric = vector_store.distance_metric or "COSINE"
index = await SQLiteVecIndex.create(
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier
vector_store.embedding_dimension, self.config.db_path, vector_store.identifier, distance_metric
)
self.cache[vector_store.identifier] = VectorStoreWithIndex(vector_store, index, self.inference_api)

25 changes: 23 additions & 2 deletions llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -45,10 +45,14 @@ async def maybe_await(result):


class ChromaIndex(EmbeddingIndex):
def __init__(self, client: ChromaClientType, collection, kvstore: KVStore | None = None):
def __init__(
self, client: ChromaClientType, collection, kvstore: KVStore | None = None, distance_metric: str = "COSINE"
):
self.client = client
self.collection = collection
self.kvstore = kvstore
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric

async def initialize(self):
pass
@@ -102,6 +106,22 @@ async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> No
ids = [f"{chunk.document_id}:{chunk.chunk_id}" for chunk in chunks_for_deletion]
await maybe_await(self.collection.delete(ids=ids))

def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by Chroma.

Args:
distance_metric: The distance metric to check

Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "COSINE":
# TODO: Implement support for other distance metrics in Chroma
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the Chroma provider. "
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
)

async def query_hybrid(
self,
embedding: NDArray,
@@ -157,8 +177,9 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
name=vector_store.identifier, metadata={"vector_store": vector_store.model_dump_json()}
)
)
distance_metric = vector_store.distance_metric or "COSINE"
self.cache[vector_store.identifier] = VectorStoreWithIndex(
vector_store, ChromaIndex(self.client, collection), self.inference_api
vector_store, ChromaIndex(self.client, collection, distance_metric=distance_metric), self.inference_api
)

async def unregister_vector_store(self, vector_store_id: str) -> None:
33 changes: 31 additions & 2 deletions llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -44,12 +44,19 @@

class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: MilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
distance_metric: str = "COSINE",
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
self.consistency_level = consistency_level
self.kvstore = kvstore
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric

async def initialize(self):
# MilvusIndex does not require explicit initialization
@@ -260,6 +267,22 @@ async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> No
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise

def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by Milvus.

Args:
distance_metric: The distance metric to check

Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "COSINE":
# TODO: Implement support for other distance metrics in Milvus
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the Milvus provider. "
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
)


class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
@@ -316,9 +339,15 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
distance_metric = vector_store.distance_metric or "COSINE"
index = VectorStoreWithIndex(
vector_store=vector_store,
index=MilvusIndex(self.client, vector_store.identifier, consistency_level=consistency_level),
index=MilvusIndex(
self.client,
vector_store.identifier,
consistency_level=consistency_level,
distance_metric=distance_metric,
),
inference_api=self.inference_api,
)

12 changes: 10 additions & 2 deletions llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -382,8 +382,13 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
upsert_models(self.conn, [(vector_store.identifier, vector_store)])

# Create and cache the PGVector index table for the vector DB
distance_metric = vector_store.distance_metric or "COSINE" # Default to COSINE if not specified
pgvector_index = PGVectorIndex(
vector_store=vector_store, dimension=vector_store.embedding_dimension, conn=self.conn, kvstore=self.kvstore
vector_store=vector_store,
dimension=vector_store.embedding_dimension,
conn=self.conn,
kvstore=self.kvstore,
distance_metric=distance_metric,
)
await pgvector_index.initialize()
index = VectorStoreWithIndex(vector_store, index=pgvector_index, inference_api=self.inference_api)
@@ -420,7 +425,10 @@ async def _get_and_cache_vector_store_index(self, vector_store_id: str) -> Vecto
if not vector_store:
raise VectorStoreNotFoundError(vector_store_id)

index = PGVectorIndex(vector_store, vector_store.embedding_dimension, self.conn)
distance_metric = vector_store.distance_metric or "COSINE" # Default to COSINE if not specified
index = PGVectorIndex(
vector_store, vector_store.embedding_dimension, self.conn, distance_metric=distance_metric
)
await index.initialize()
self.cache[vector_store_id] = VectorStoreWithIndex(vector_store, index, self.inference_api)
return self.cache[vector_store_id]
23 changes: 21 additions & 2 deletions llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -56,9 +56,11 @@ def convert_id(_id: str) -> str:


class QdrantIndex(EmbeddingIndex):
def __init__(self, client: AsyncQdrantClient, collection_name: str):
def __init__(self, client: AsyncQdrantClient, collection_name: str, distance_metric: str = "COSINE"):
self.client = client
self.collection_name = collection_name
self._check_distance_metric_support(distance_metric)
self.distance_metric = distance_metric

async def initialize(self) -> None:
# Qdrant collections are created on-demand in add_chunks
@@ -144,6 +146,22 @@ async def query_hybrid(
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)

def _check_distance_metric_support(self, distance_metric: str) -> None:
"""Check if the distance metric is supported by Qdrant.

Args:
distance_metric: The distance metric to check

Raises:
NotImplementedError: If the distance metric is not supported yet
"""
if distance_metric != "COSINE":
# TODO: Implement support for other distance metrics in Qdrant
raise NotImplementedError(
f"Distance metric '{distance_metric}' is not yet supported by the Qdrant provider. "
f"Currently only 'COSINE' is supported. Please use 'COSINE' or switch to a different provider."
)


class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
def __init__(
@@ -187,9 +205,10 @@ async def register_vector_store(self, vector_store: VectorStore) -> None:
key = f"{VECTOR_DBS_PREFIX}{vector_store.identifier}"
await self.kvstore.set(key=key, value=vector_store.model_dump_json())

distance_metric = vector_store.distance_metric or "COSINE"
index = VectorStoreWithIndex(
vector_store=vector_store,
index=QdrantIndex(self.client, vector_store.identifier),
index=QdrantIndex(self.client, vector_store.identifier, distance_metric=distance_metric),
inference_api=self.inference_api,
)
