diff --git a/.github/workflows/ci-unit-tests.yml b/.github/workflows/ci-unit-tests.yml
index a23c5e227..b6ce09026 100644
--- a/.github/workflows/ci-unit-tests.yml
+++ b/.github/workflows/ci-unit-tests.yml
@@ -86,6 +86,9 @@ jobs:
       - name: "Setup: Python 3.11"
         uses: ./.github/actions/setup-python
 
+      - name: "Type check (colbert)"
+        run: tox -e type -c libs/colbert && rm -rf libs/colbert/.tox
+
       - name: "Type check (knowledge-graph)"
         run: tox -e type -c libs/knowledge-graph && rm -rf libs/knowledge-graph/.tox
diff --git a/libs/colbert/pyproject.toml b/libs/colbert/pyproject.toml
index 05707f977..7c47192f7 100644
--- a/libs/colbert/pyproject.toml
+++ b/libs/colbert/pyproject.toml
@@ -21,6 +21,8 @@ pydantic = "^2.7.1"
 # Remove when we upgrade to pytorch 2.4
 setuptools = { version = ">=70", python = ">=3.12" }
 
+[tool.poetry.group.dev.dependencies]
+mypy = "^1.11.0"
 
 [tool.poetry.group.test.dependencies]
 ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
@@ -28,3 +30,20 @@ pytest-asyncio = "^0.23.6"
 
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
+
+[tool.mypy]
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+follow_imports = "normal"
+ignore_missing_imports = true
+no_implicit_reexport = true
+show_error_codes = true
+show_error_context = true
+strict_equality = true
+strict_optional = true
+warn_redundant_casts = true
+warn_return_any = true
+warn_unused_ignores = true
diff --git a/libs/colbert/ragstack_colbert/base_database.py b/libs/colbert/ragstack_colbert/base_database.py
index 3371a26a7..b5bb7b2c5 100644
--- a/libs/colbert/ragstack_colbert/base_database.py
+++ b/libs/colbert/ragstack_colbert/base_database.py
@@ -6,7 +6,7 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Tuple
 
 from .objects import Chunk, Vector
 
@@ -48,13 +48,13 @@ def delete_chunks(self, doc_ids: List[str]) -> bool:
 
     @abstractmethod
     async def aadd_chunks(
-        self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
+        self, chunks: List[Chunk], concurrent_inserts: int = 100
     ) -> List[Tuple[str, int]]:
         """Stores a list of embedded text chunks in the vector store.
 
         Args:
-            chunks (List[Chunk]): A list of `Chunk` instances to be stored.
-            concurrent_inserts (Optional[int]): How many concurrent inserts to make to
+            chunks: A list of `Chunk` instances to be stored.
+            concurrent_inserts: How many concurrent inserts to make to
                 the database. Defaults to 100.
 
         Returns:
@@ -63,14 +63,14 @@ async def adelete_chunks(
-        self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
+        self, doc_ids: List[str], concurrent_deletes: int = 100
     ) -> bool:
         """Deletes chunks from the vector store based on their document id.
 
         Args:
-            doc_ids (List[str]): A list of document identifiers specifying the chunks
+            doc_ids: A list of document identifiers specifying the chunks
                 to be deleted.
-            concurrent_deletes (Optional[int]): How many concurrent deletes to make
+            concurrent_deletes: How many concurrent deletes to make
                 to the database. Defaults to 100.
 
         Returns:
@@ -96,7 +96,7 @@ async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk:
 
     @abstractmethod
     async def get_chunk_data(
-        self, doc_id: str, chunk_id: int, include_embedding: Optional[bool]
+        self, doc_id: str, chunk_id: int, include_embedding: bool = False
     ) -> Chunk:
         """Retrieve the text and metadata for a chunk.
diff --git a/libs/colbert/ragstack_colbert/base_embedding_model.py b/libs/colbert/ragstack_colbert/base_embedding_model.py
index afe7c54f1..0e58285d5 100644
--- a/libs/colbert/ragstack_colbert/base_embedding_model.py
+++ b/libs/colbert/ragstack_colbert/base_embedding_model.py
@@ -35,8 +35,8 @@ def embed_texts(self, texts: List[str]) -> List[Embedding]:
     def embed_query(
         self,
         query: str,
-        full_length_search: Optional[bool] = False,
-        query_maxlen: int = -1,
+        full_length_search: bool = False,
+        query_maxlen: Optional[int] = None,
     ) -> Embedding:
         """Embeds a single query text into its vector representation.
 
@@ -44,12 +44,12 @@ special [mast] tokens.
 
         Args:
-            query (str): The query text to encode.
-            full_length_search (Optional[bool]): Indicates whether to encode the
+            query: The query text to encode.
+            full_length_search: Indicates whether to encode the
                 query for a full-length search. Defaults to False.
-            query_maxlen (int): The fixed length for the query token embedding.
-                If -1, uses a dynamically calculated value.
+            query_maxlen: The fixed length for the query token embedding.
+                If None, uses a dynamically calculated value.
 
         Returns:
-            Embedding: A vector embedding representation of the query text
+            A vector embedding representation of the query text
         """
diff --git a/libs/colbert/ragstack_colbert/base_retriever.py b/libs/colbert/ragstack_colbert/base_retriever.py
index ff50383af..a0486ffcf 100644
--- a/libs/colbert/ragstack_colbert/base_retriever.py
+++ b/libs/colbert/ragstack_colbert/base_retriever.py
@@ -25,7 +25,7 @@ def embedding_search(
         self,
         query_embedding: Embedding,
         k: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query embedding.
@@ -34,16 +34,16 @@ store, ranked by relevance or other metrics.
 
         Args:
-            query_embedding (Embedding): The query embedding to search for relevant
+            query_embedding: The query embedding to search for relevant
                 text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            k: The number of top results to retrieve.
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.
 
         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
@@ -54,7 +54,7 @@ async def aembedding_search(
         self,
         query_embedding: Embedding,
         k: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query embedding.
@@ -63,16 +63,16 @@ store, ranked by relevance or other metrics.
 
         Args:
-            query_embedding (Embedding): The query embedding to search for relevant
+            query_embedding: The query embedding to search for relevant
                 text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            k: The number of top results to retrieve.
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.
 
         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
@@ -84,7 +84,7 @@ def text_search(
         query_text: str,
         k: Optional[int] = None,
         query_maxlen: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query text.
@@ -93,17 +93,17 @@ store, ranked by relevance or other metrics.
 
         Args:
-            query_text (str): The query text to search for relevant text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            query_maxlen (Optional[int]): The maximum length of the query to consider.
+            query_text: The query text to search for relevant text chunks.
+            k: The number of top results to retrieve.
+            query_maxlen: The maximum length of the query to consider.
                 If None, the maxlen will be dynamically generated.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.
 
         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
@@ -115,7 +115,7 @@ async def atext_search(
         query_text: str,
         k: Optional[int] = None,
         query_maxlen: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query text.
@@ -124,17 +124,17 @@ store, ranked by relevance or other metrics.
 
         Args:
-            query_text (str): The query text to search for relevant text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            query_maxlen (Optional[int]): The maximum length of the query to consider.
+            query_text: The query text to search for relevant text chunks.
+            k: The number of top results to retrieve.
+            query_maxlen: The maximum length of the query to consider.
                 If None, the maxlen will be dynamically generated.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.
 
         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
""" diff --git a/libs/colbert/ragstack_colbert/base_vector_store.py b/libs/colbert/ragstack_colbert/base_vector_store.py index 1f95ccaf6..7f2a42c73 100644 --- a/libs/colbert/ragstack_colbert/base_vector_store.py +++ b/libs/colbert/ragstack_colbert/base_vector_store.py @@ -87,13 +87,13 @@ def delete_chunks(self, doc_ids: List[str]) -> bool: # handles LlamaIndex add @abstractmethod async def aadd_chunks( - self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100 + self, chunks: List[Chunk], concurrent_inserts: int = 100 ) -> List[Tuple[str, int]]: """Stores a list of embedded text chunks in the vector store. Args: - chunks (List[Chunk]): A list of `Chunk` instances to be stored. - concurrent_inserts (Optional[int]): How many concurrent inserts to make to + chunks: A list of `Chunk` instances to be stored. + concurrent_inserts: How many concurrent inserts to make to the database. Defaults to 100. Returns: @@ -107,7 +107,7 @@ async def aadd_texts( texts: List[str], metadatas: Optional[List[Metadata]], doc_id: Optional[str] = None, - concurrent_inserts: Optional[int] = 100, + concurrent_inserts: int = 100, ) -> List[Tuple[str, int]]: """Adds text chunks to the vector store. @@ -115,12 +115,12 @@ async def aadd_texts( store. Args: - texts (List[str]): The list of text chunks to be embedded - metadatas (Optional[List[Metadata]])): An optional list of Metadata to be + texts: The list of text chunks to be embedded + metadatas: An optional list of Metadata to be stored. If provided, these are set 1 to 1 with the texts list. - doc_id (Optional[str]): The document id associated with the texts. + doc_id: The document id associated with the texts. If not provided, it is generated. - concurrent_inserts (Optional[int]): How many concurrent inserts to make to + concurrent_inserts: How many concurrent inserts to make to the database. Defaults to 100. Returns: @@ -130,14 +130,14 @@ async def aadd_texts( # handles LangChain and LlamaIndex delete @abstractmethod async def adelete_chunks( - self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100 + self, doc_ids: List[str], concurrent_deletes: int = 100 ) -> bool: """Deletes chunks from the vector store based on their document id. Args: - doc_ids (List[str]): A list of document identifiers specifying the chunks + doc_ids: A list of document identifiers specifying the chunks to be deleted. - concurrent_deletes (Optional[int]): How many concurrent deletes to make to + concurrent_deletes: How many concurrent deletes to make to the database. Defaults to 100. 
 
         Returns:
diff --git a/libs/colbert/ragstack_colbert/cassandra_database.py b/libs/colbert/ragstack_colbert/cassandra_database.py
index 790468227..2942cae47 100644
--- a/libs/colbert/ragstack_colbert/cassandra_database.py
+++ b/libs/colbert/ragstack_colbert/cassandra_database.py
@@ -9,12 +9,13 @@
 import asyncio
 import logging
 from collections import defaultdict
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple
 
 import cassio
 from cassandra.cluster import Session
 from cassio.table.query import Predicate, PredicateOperator
 from cassio.table.tables import ClusteredMetadataVectorCassandraTable
+from typing_extensions import Self, override
 
 from .base_database import BaseDatabase
 from .constant import DEFAULT_COLBERT_DIM
@@ -39,7 +40,7 @@ class CassandraDatabase(BaseDatabase):
 
     _table: ClusteredMetadataVectorCassandraTable
 
-    def __new__(cls):  # noqa: D102
+    def __new__(cls) -> Self:  # noqa: D102
         raise ValueError(
             "This class cannot be instantiated directly. "
             "Please use the `from_astra()` or `from_session()` class methods."
@@ -51,12 +52,12 @@ def from_astra(
         database_id: str,
         astra_token: str,
         keyspace: Optional[str] = "default_keyspace",
-        table_name: Optional[str] = "colbert",
+        table_name: str = "colbert",
         timeout: Optional[int] = 300,
-    ):
+    ) -> Self:
         """Creates a CassandraVectorStore using AstraDB connection info."""
         cassio.init(token=astra_token, database_id=database_id, keyspace=keyspace)
-        session = cassio.config.resolve_session()
+        session = cassio.config.check_resolve_session()
         session.default_timeout = timeout
 
         return cls.from_session(
@@ -68,8 +69,8 @@ def from_session(
         cls,
         session: Session,
         keyspace: Optional[str] = "default_keyspace",
-        table_name: Optional[str] = "colbert",
-    ):
+        table_name: str = "colbert",
+    ) -> Self:
         """Creates a CassandraVectorStore using an existing session."""
         instance = super().__new__(cls)
         instance._initialize(session=session, keyspace=keyspace, table_name=table_name)  # noqa: SLF001
@@ -78,17 +79,15 @@ def from_session(
     def _initialize(
         self,
         session: Session,
-        keyspace: str,
+        keyspace: Optional[str],
         table_name: str,
-    ):
+    ) -> None:
        """Initializes a new instance of the CassandraVectorStore.
 
         Args:
-            session (Session): The Cassandra session to use.
-            keyspace (str): The keyspace in which the table exists or will be created.
-            table_name (str): The name of the table to use or create for storing
+            session: The Cassandra session to use.
+            keyspace: The keyspace in which the table exists or will be created.
+            table_name: The name of the table to use or create for storing
                 embeddings.
-            timeout (int, optional): The default timeout in seconds for Cassandra
-                operations. Defaults to 180.
         """
         try:
@@ -113,7 +114,7 @@ def _initialize(
 
     def _log_insert_error(
         self, doc_id: str, chunk_id: int, embedding_id: int, exp: Exception
-    ):
+    ) -> None:
         if embedding_id == -1:
             logging.error(
                 "issue inserting document data: %s chunk: %s: %s", doc_id, chunk_id, exp
@@ -127,15 +128,8 @@ def _log_insert_error(
                 exp,
             )
 
+    @override
     def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]:
-        """Stores a list of embedded text chunks in the vector store.
-
-        Args:
-            chunks (List[Chunk]): A list of `Chunk` instances to be stored.
-
-        Returns:
-            a list of tuples: (doc_id, chunk_id)
-        """
         failed_chunks: List[Tuple[str, int]] = []
         success_chunks: List[Tuple[str, int]] = []
 
@@ -157,19 +151,20 @@ def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]:
                 failed_chunks.append((doc_id, chunk_id))
                 continue
 
-            for embedding_id, vector in enumerate(chunk.embedding):
-                try:
-                    self._table.put(
-                        partition_id=doc_id,
-                        row_id=(chunk_id, embedding_id),
-                        vector=vector,
-                    )
-                except Exception as exp:  # noqa: BLE001
-                    self._log_insert_error(
-                        doc_id=doc_id, chunk_id=chunk_id, embedding_id=-1, exp=exp
-                    )
-                    failed_chunks.append((doc_id, chunk_id))
-                    continue
+            if chunk.embedding:
+                for embedding_id, vector in enumerate(chunk.embedding):
+                    try:
+                        self._table.put(
+                            partition_id=doc_id,
+                            row_id=(chunk_id, embedding_id),
+                            vector=vector,
+                        )
+                    except Exception as exp:  # noqa: BLE001
+                        self._log_insert_error(
+                            doc_id=doc_id, chunk_id=chunk_id, embedding_id=-1, exp=exp
+                        )
+                        failed_chunks.append((doc_id, chunk_id))
+                        continue
 
             success_chunks.append((doc_id, chunk_id))
 
@@ -186,7 +181,7 @@ async def _limited_put(
         sem: asyncio.Semaphore,
         doc_id: str,
         chunk_id: int,
-        embedding_id: Optional[int] = -1,
+        embedding_id: int = -1,
         text: Optional[str] = None,
         metadata: Optional[Dict[str, Any]] = None,
         vector: Optional[Vector] = None,
@@ -209,22 +204,13 @@
                 return doc_id, chunk_id, embedding_id, e
         return doc_id, chunk_id, embedding_id, None
 
+    @override
     async def aadd_chunks(
-        self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
+        self, chunks: List[Chunk], concurrent_inserts: int = 100
     ) -> List[Tuple[str, int]]:
-        """Stores a list of embedded text chunks in the vector store.
-
-        Args:
-            chunks (List[Chunk]): A list of `Chunk` instances to be stored.
-            concurrent_inserts (Optional[int]): How many concurrent inserts to make to
-                the database. Defaults to 100.
-
-        Returns:
-            a list of tuples: (doc_id, chunk_id)
-        """
         semaphore = asyncio.Semaphore(concurrent_inserts)
-        all_tasks = []
-        tasks_per_chunk = defaultdict(int)
+        all_tasks: List[Awaitable[Tuple[str, int, int, Optional[Exception]]]] = []
+        tasks_per_chunk: Dict[Tuple[str, int], int] = defaultdict(int)
 
         for chunk in chunks:
             doc_id = chunk.doc_id
@@ -243,27 +229,35 @@ async def aadd_chunks(
                 )
             tasks_per_chunk[(doc_id, chunk_id)] += 1
 
-            for index, vector in enumerate(chunk.embedding):
-                all_tasks.append(
-                    self._limited_put(
-                        sem=semaphore,
-                        doc_id=doc_id,
-                        chunk_id=chunk_id,
-                        embedding_id=index,
-                        vector=vector,
+            if chunk.embedding:
+                for index, vector in enumerate(chunk.embedding):
+                    all_tasks.append(
+                        self._limited_put(
+                            sem=semaphore,
+                            doc_id=doc_id,
+                            chunk_id=chunk_id,
+                            embedding_id=index,
+                            vector=vector,
+                        )
                     )
-                )
-                tasks_per_chunk[(doc_id, chunk_id)] += 1
+                    tasks_per_chunk[(doc_id, chunk_id)] += 1
 
         results = await asyncio.gather(*all_tasks, return_exceptions=True)
 
-        for doc_id, chunk_id, embedding_id, exp in results:
-            if exp is None:
-                tasks_per_chunk[(doc_id, chunk_id)] -= 1
+        for result in results:
+            if isinstance(result, BaseException):
+                logging.error("issue inserting data", exc_info=result)
             else:
-                self._log_insert_error(
-                    doc_id=doc_id, chunk_id=chunk_id, embedding_id=embedding_id, exp=exp
-                )
+                doc_id, chunk_id, embedding_id, exp = result
+                if exp is None:
+                    tasks_per_chunk[(doc_id, chunk_id)] -= 1
+                else:
+                    self._log_insert_error(
+                        doc_id=doc_id,
+                        chunk_id=chunk_id,
+                        embedding_id=embedding_id,
+                        exp=exp,
+                    )
 
         outputs: List[Tuple[str, int]] = []
         failed_chunks: List[Tuple[str, int]] = []
@@ -282,16 +276,8 @@ async def aadd_chunks(
 
         return outputs
 
+    @override
     def delete_chunks(self, doc_ids: List[str]) -> bool:
-        """Deletes chunks from the vector store based on their document id.
-
-        Args:
-            doc_ids (List[str]): A list of document identifiers specifying the chunks
-                to be deleted.
-
-        Returns:
-            True if the all the deletes were successful.
-        """
         failed_docs: List[str] = []
 
         for doc_id in doc_ids:
@@ -321,20 +307,10 @@ async def _limited_delete(
             return doc_id, e
         return doc_id, None
 
+    @override
     async def adelete_chunks(
-        self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
+        self, doc_ids: List[str], concurrent_deletes: int = 100
     ) -> bool:
-        """Deletes chunks from the vector store based on their document id.
-
-        Args:
-            doc_ids (List[str]): A list of document identifiers specifying the chunks
-                to be deleted.
-            concurrent_deletes (Optional[int]): How many concurrent deletes to make
-                to the database. Defaults to 100.
-
-        Returns:
-            True if the all the deletes were successful.
- """ semaphore = asyncio.Semaphore(concurrent_deletes) all_tasks = [ self._limited_delete( @@ -349,11 +325,15 @@ async def adelete_chunks( success = True failed_docs: List[str] = [] - for doc_id, exp in results: - if exp is not None: - logging.error("issue deleting document: %s: %s", doc_id, exp) - success = False - failed_docs.append(doc_id) + for result in results: + if isinstance(result, BaseException): + logging.error("issue inserting data", exc_info=result) + else: + doc_id, exp = result + if exp is not None: + logging.error("issue deleting document: %s", doc_id, exc_info=exp) + success = False + failed_docs.append(doc_id) if len(failed_docs) > 0: raise CassandraDatabaseError( @@ -363,13 +343,8 @@ async def adelete_chunks( return success + @override async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: - """Retrieves 'n' ANN results for an embedded token vector. - - Returns: - A list of Chunks with only `doc_id` and `chunk_id` set. - Fewer than 'n' results may be returned. - """ chunks: Set[Chunk] = set() # TODO: only return partition_id and row_id after cassio supports this @@ -383,12 +358,8 @@ async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: ) return list(chunks) + @override async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk: - """Retrieve the embedding data for a chunk. - - Returns: - A chunk with `doc_id`, `chunk_id`, and `embedding` set. - """ row_id = (chunk_id, Predicate(PredicateOperator.GT, -1)) rows = await self._table.aget_partition(partition_id=doc_id, row_id=row_id) @@ -396,17 +367,18 @@ async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk: return Chunk(doc_id=doc_id, chunk_id=chunk_id, embedding=embedding) + @override async def get_chunk_data( - self, doc_id: str, chunk_id: int, include_embedding: Optional[bool] = False + self, doc_id: str, chunk_id: int, include_embedding: bool = False ) -> Chunk: - """Retrieve the text and metadata for a chunk. - - Returns: - A chunk with `doc_id`, `chunk_id`, `text`, and `metadata` set. 
- """ row_id = (chunk_id, Predicate(PredicateOperator.EQ, -1)) row = await self._table.aget(partition_id=doc_id, row_id=row_id) + if row is None: + raise CassandraDatabaseError( + f"no chunk found for doc_id: {doc_id} chunk_id: {chunk_id}" + ) + if include_embedding is True: embedded_chunk = await self.get_chunk_embedding( doc_id=doc_id, chunk_id=chunk_id @@ -423,5 +395,6 @@ async def get_chunk_data( embedding=embedding, ) + @override def close(self) -> None: - """Cleans up any open resources.""" + pass diff --git a/libs/colbert/ragstack_colbert/colbert_embedding_model.py b/libs/colbert/ragstack_colbert/colbert_embedding_model.py index 42fcf2179..e76cb4274 100644 --- a/libs/colbert/ragstack_colbert/colbert_embedding_model.py +++ b/libs/colbert/ragstack_colbert/colbert_embedding_model.py @@ -13,6 +13,7 @@ from typing import List, Optional from colbert.infra import ColBERTConfig +from typing_extensions import override from .base_embedding_model import BaseEmbeddingModel from .constant import DEFAULT_COLBERT_MODEL @@ -37,14 +38,14 @@ class ColbertEmbeddingModel(BaseEmbeddingModel): def __init__( self, - checkpoint: Optional[str] = DEFAULT_COLBERT_MODEL, - doc_maxlen: Optional[int] = 256, - nbits: Optional[int] = 2, - kmeans_niters: Optional[int] = 4, - nranks: Optional[int] = -1, + checkpoint: str = DEFAULT_COLBERT_MODEL, + doc_maxlen: int = 256, + nbits: int = 2, + kmeans_niters: int = 4, + nranks: int = -1, query_maxlen: Optional[int] = None, - verbose: Optional[int] = 3, # 3 is the default on ColBERT checkpoint - chunk_batch_size: Optional[int] = 640, + verbose: int = 3, # 3 is the default on ColBERT checkpoint + chunk_batch_size: int = 640, ): """Initializes a new instance of the ColbertEmbeddingModel class. @@ -53,20 +54,19 @@ def __init__( tokenizer and encoder. Args: - checkpoint (Optional[str]): Path or URL to the Colbert model checkpoint. + checkpoint: Path or URL to the Colbert model checkpoint. Default is a pre-defined model. - doc_maxlen (Optional[int]): Maximum number of tokens for document chunks. + doc_maxlen: Maximum number of tokens for document chunks. Should equal the chunk_size. - nbits (Optional[int]): The number bits that each dimension encodes to. - kmeans_niters (Optional[int]): Number of iterations for k-means clustering + nbits: The number bits that each dimension encodes to. + kmeans_niters: Number of iterations for k-means clustering during quantization. - nranks (Optional[int]): Number of ranks (processors) to use for distributed + nranks: Number of ranks (processors) to use for distributed computing; -1 uses all available CPUs/GPUs. - query_maxlen (Optional[int]): Maximum length of query tokens for embedding. - verbose (Optional[int]): Verbosity level for logging. - chunk_batch_size (Optional[int]): The number of chunks to batch during + query_maxlen: Maximum length of query tokens for embedding. + verbose: Verbosity level for logging. + chunk_batch_size: The number of chunks to batch during embedding. Defaults to 640. - **kwargs: Additional keyword arguments for future extensions. """ if query_maxlen is None: query_maxlen = -1 @@ -83,16 +83,8 @@ def __init__( ) self._encoder = TextEncoder(config=colbert_config, verbose=verbose) - # implements the Abstract Class Method + @override def embed_texts(self, texts: List[str]) -> List[Embedding]: - """Embeds a list of texts into their vector embedding representations. - - Args: - texts (List[str]): A list of string texts. 
-
-        Returns:
-            List[Embedding]: A list of embeddings, in the order of the input list
-        """
         chunks = [
             Chunk(doc_id="dummy", chunk_id=i, text=t) for i, t in enumerate(texts)
         ]
@@ -105,30 +97,15 @@ def embed_texts(self, texts: List[str]) -> List[Embedding]:
 
         sorted_embedded_chunks = sorted(embedded_chunks, key=lambda c: c.chunk_id)
 
-        return [c.embedding for c in sorted_embedded_chunks]
+        return [c.embedding or [] for c in sorted_embedded_chunks]
 
-    # implements the Abstract Class Method
+    @override
     def embed_query(
         self,
         query: str,
-        full_length_search: Optional[bool] = False,
+        full_length_search: bool = False,
         query_maxlen: Optional[int] = None,
     ) -> Embedding:
-        """Embeds a single query text into its vector representation.
-
-        If the query has fewer than query_maxlen tokens it will be padded with BERT
-        special [mast] tokens.
-
-        Args:
-            query (str): The query string to encode.
-            full_length_search (Optional[bool]): Indicates whether to encode the query
-                for a full-length search. Defaults to False.
-            query_maxlen (int): The fixed length for the query token embedding.
-                If None, uses a dynamically calculated value.
-
-        Returns:
-            Embedding: A vector embedding representation of the query text
-        """
         if query_maxlen is None:
             query_maxlen = -1
 
diff --git a/libs/colbert/ragstack_colbert/colbert_retriever.py b/libs/colbert/ragstack_colbert/colbert_retriever.py
index dbf3da59d..ee3d891ab 100644
--- a/libs/colbert/ragstack_colbert/colbert_retriever.py
+++ b/libs/colbert/ragstack_colbert/colbert_retriever.py
@@ -25,11 +25,11 @@
 from .objects import Chunk, Embedding, Vector
 
 
-def all_gpus_support_fp16(is_cuda: Optional[bool] = False):
+def all_gpus_support_fp16(is_cuda: bool = False) -> bool:
     """Check if all available GPU devices support FP16 (half-precision) operations.
 
     Returns:
-        bool: True if all GPUs support FP16, False otherwise.
+        True if all GPUs support FP16, False otherwise.
     """
     if not is_cuda:
         return False
@@ -58,8 +58,8 @@
 def max_similarity_torch(
     query_vector: Vector,
     chunk_embedding: Embedding,
-    is_cuda: Optional[bool] = False,
-    is_fp16: Optional[bool] = False,
+    is_cuda: bool = False,
+    is_fp16: bool = False,
 ) -> float:
     """Calculates the maximum similarity for a query vector and a chunk embedding.
 
@@ -67,17 +67,17 @@ def max_similarity_torch(
     chunk embedding, leveraging PyTorch for efficient computation.
 
     Args:
-        query_vector (Vector): A list of float representing the query text.
-        chunk_embedding (Embedding): A list of Vector, each representing an chunk
+        query_vector: A list of float representing the query text.
+        chunk_embedding: A list of Vector, each representing a chunk
            embedding vector.
-        is_cuda (Optional[bool]): A flag indicating whether to use CUDA (GPU)
+        is_cuda: A flag indicating whether to use CUDA (GPU)
            for computation. Defaults to False.
-        is_fp16 (bool): A flag indicating whether to half-precision floating point
+        is_fp16: A flag indicating whether to use half-precision floating point
            operations on CUDA (GPU). Has no effect on CPU computation.
            Defaults to False.
 
     Returns:
-        Tensor: A tensor containing the highest similarity score (dot product value)
+        The highest similarity score (dot product value)
        found between the query vector and any of the embedding vectors in the list.
 
     Note:
@@ -108,16 +108,6 @@ def max_similarity_torch(
     return float(max_sim.item())
 
 
-def get_trace(e: Exception) -> str:
-    """Extracts the traceback information from an exception."""
-    trace = ""
-    tb = e.__traceback__
-    while tb is not None:
-        trace += f"\tFile: {tb.tb_frame.f_code.co_filename} Line: {tb.tb_lineno}\n"
-        tb = tb.tb_next
-    return trace
-
-
 class ColbertRetriever(BaseRetriever):
     """ColBERT Retriever.
 
@@ -157,9 +147,6 @@ def __init__(
         self._is_cuda = torch.cuda.is_available()
         self._is_fp16 = all_gpus_support_fp16(self._is_cuda)
 
-    def close(self) -> None:
-        """Closes any open resources held by the retriever."""
-
     async def _query_relevant_chunks(
         self, query_embedding: Embedding, top_k: int
     ) -> Set[Chunk]:
@@ -174,11 +161,10 @@ async def _query_relevant_chunks(
         # Process results and handle potential exceptions
         for result in results:
-            if isinstance(result, Exception):
+            if isinstance(result, BaseException):
                 logging.error(
-                    "Issue on database.get_relevant_chunks(): %s at %s",
-                    result,
-                    get_trace(result),
+                    "Issue on database.get_relevant_chunks()",
+                    exc_info=result,
                 )
             else:
                 chunks.update(result)
@@ -195,15 +181,17 @@ async def _get_chunk_embeddings(self, chunks: Set[Chunk]) -> List[Chunk]:
         results = await asyncio.gather(*tasks, return_exceptions=True)
 
         # Process results and handle potential exceptions
+        chunk_embeddings = []
         for result in results:
-            if isinstance(result, Exception):
+            if isinstance(result, BaseException):
                 logging.error(
-                    "Issue on database.get_chunk_embeddings(): %s at %s",
-                    result,
-                    get_trace(result),
+                    "Issue on database.get_chunk_embeddings()",
+                    exc_info=result,
                 )
+            else:
+                chunk_embeddings.append(result)
 
-        return results
+        return chunk_embeddings
 
     def _score_chunks(
         self, query_embedding: Embedding, chunk_embeddings: List[Chunk]
@@ -211,6 +199,8 @@ def _score_chunks(
         """Process the retrieved chunk data to calculate scores."""
         chunk_scores = {}
         for chunk in chunk_embeddings:
+            if not chunk.embedding:
+                continue
             chunk_scores[chunk] = sum(
                 max_similarity_torch(
                     query_vector=query_vector,
@@ -225,7 +215,7 @@
     async def _get_chunk_data(
         self,
         chunks: List[Chunk],
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
     ) -> List[Chunk]:
         """Fetches text and metadata for each chunk.
 
@@ -244,15 +234,17 @@ async def _get_chunk_data(
         ]
         results = await asyncio.gather(*tasks, return_exceptions=True)
 
+        chunks = []
         for result in results:
-            if isinstance(result, Exception):
+            if isinstance(result, BaseException):
                 logging.error(
-                    "Issue on database.get_chunk_data(): %s at %s",
-                    result,
-                    get_trace(result),
+                    "Issue on database.get_chunk_data()",
+                    exc_info=result,
                 )
+            else:
+                chunks.append(result)
 
-        return results
+        return chunks
 
     @override
     async def atext_search(
         self,
         query_text: str,
         k: Optional[int] = 5,
         query_maxlen: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
-        """Searches for relevant text chunks based on a given query text.
-
-        Retrieves a list of text chunks most relevant to the given query,
-        using semantic similarity as the criteria.
-
-        Args:
-            query_text (str): The query text to search for relevant text chunks.
-            k (Optional[int]): The number of top results to retrieve. Default 5.
-            query_maxlen (Optional[int]): The maximum length of the query to consider.
-                If None, the maxlen will be dynamically generated.
-            include_embedding (Optional[bool]): Optional (default False) flag to include
-                the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
-                for customized retrieval operations.
-
-        Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
-                each representing a text chunk that is relevant to the query,
-                along with its similarity score.
-        """
         query_embedding = self._embedding_model.embed_query(
             query=query_text, query_maxlen=query_maxlen
         )
@@ -299,28 +271,11 @@ async def aembedding_search(
         self,
         query_embedding: Embedding,
         k: Optional[int] = 5,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
-        """Searches for relevant text chunks based on a given query embedding.
-
-        Retrieves a list of text chunks most relevant to the given query,
-        using semantic similarity as the criteria.
-
-        Args:
-            query_embedding (Embedding): The query embedding to search for relevant
-                text chunks.
-            k (Optional[int]): The number of top results to retrieve. Default 5.
-            include_embedding (Optional[bool]): Optional (default False) flag to include
-                the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
-                for customized retrieval operations.
-
-        Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
-                each representing a text chunk that is relevant to the query,
-                along with its similarity score.
-        """
+        if k is None:
+            k = 5
         top_k = max(math.floor(len(query_embedding) / 2), 16)
         logging.debug(
             "based on query length of %s tokens, retrieving %s results per "
@@ -330,7 +285,7 @@ async def aembedding_search(
         )
 
         # search for relevant chunks (only with `doc_id` and `chunk_id` set)
-        relevant_chunks: List[Chunk] = await self._query_relevant_chunks(
+        relevant_chunks: Set[Chunk] = await self._query_relevant_chunks(
             query_embedding=query_embedding, top_k=top_k
         )
 
@@ -348,7 +303,7 @@ async def aembedding_search(
 
         # only keep the top k sorted results
         top_k_chunks: List[Chunk] = sorted(
-            chunk_scores, key=chunk_scores.get, reverse=True
+            chunk_scores, key=lambda c: chunk_scores.get(c, 0), reverse=True
         )[:k]
 
         chunks: List[Chunk] = await self._get_chunk_data(
@@ -363,29 +318,9 @@ def text_search(
         self,
         query_text: str,
         k: Optional[int] = 5,
         query_maxlen: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
-        """Searches for relevant text chunks based on a given query text.
-
-        Retrieves a list of text chunks relevant to a given query from the vector
-        store, ranked by relevance or other metrics.
-
-        Args:
-            query_text (str): The query text to search for relevant text chunks.
-            k (Optional[int]): The number of top results to retrieve. Default 5.
-            query_maxlen (Optional[int]): The maximum length of the query to consider.
-                If None, the maxlen will be dynamically generated.
-            include_embedding (Optional[bool]): Optional (default False) flag to
-                include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
-                for customized retrieval operations.
-
-        Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
-                each representing a text chunk that is relevant to the query,
-                along with its similarity score.
- """ return asyncio.run( self.atext_search( query_text=query_text, @@ -401,28 +336,9 @@ def embedding_search( self, query_embedding: Embedding, k: Optional[int] = 5, - include_embedding: Optional[bool] = False, + include_embedding: bool = False, **kwargs: Any, ) -> List[Tuple[Chunk, float]]: - """Searches for relevant text chunks based on a given query embedding. - - Retrieves a list of text chunks relevant to a given query from the vector - store, ranked by relevance or other metrics. - - Args: - query_embedding (Embedding): The query embedding to search for relevant - text chunks. - k (Optional[int]): The number of top results to retrieve. Default 5. - include_embedding (Optional[bool]): Optional (default False) flag to - include the embedding vectors in the returned chunks - **kwargs (Any): Additional parameters that implementations might require - for customized retrieval operations. - - Returns: - List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples, - each representing a text chunk that is relevant to the query, - along with its similarity score. - """ return asyncio.run( self.aembedding_search( query_embedding=query_embedding, diff --git a/libs/colbert/ragstack_colbert/colbert_vector_store.py b/libs/colbert/ragstack_colbert/colbert_vector_store.py index 49c38572e..ca00465fa 100644 --- a/libs/colbert/ragstack_colbert/colbert_vector_store.py +++ b/libs/colbert/ragstack_colbert/colbert_vector_store.py @@ -10,6 +10,8 @@ import uuid from typing import List, Optional, Tuple +from typing_extensions import override + from .base_database import BaseDatabase from .base_embedding_model import BaseEmbeddingModel from .base_retriever import BaseRetriever @@ -28,7 +30,7 @@ class ColbertVectorStore(BaseVectorStore): """ _database: BaseDatabase - _embedding_model: BaseEmbeddingModel + _embedding_model: Optional[BaseEmbeddingModel] def __init__( self, @@ -38,11 +40,12 @@ def __init__( self._database = database self._embedding_model = embedding_model - def _validate_embedding_model(self): + def _validate_embedding_model(self) -> BaseEmbeddingModel: if self._embedding_model is None: raise AttributeError( "To use this method, `embedding_model` must be set on class creation." ) + return self._embedding_model def _build_chunks( self, @@ -50,7 +53,7 @@ def _build_chunks( metadatas: Optional[List[Metadata]] = None, doc_id: Optional[str] = None, ) -> List[Chunk]: - self._validate_embedding_model() + embedding_model = self._validate_embedding_model() if metadatas is not None and len(texts) != len(metadatas): raise ValueError("Length of texts and metadatas must match.") @@ -58,7 +61,7 @@ def _build_chunks( if doc_id is None: doc_id = str(uuid.uuid4()) - embeddings = self._embedding_model.embed_texts(texts=texts) + embeddings = embedding_model.embed_texts(texts=texts) chunks: List[Chunk] = [] for i, text in enumerate(texts): @@ -73,124 +76,56 @@ def _build_chunks( ) return chunks - # implements the abc method to handle LlamaIndex add + @override def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: - """Stores a list of embedded text chunks in the vector store. - - Args: - chunks (List[Chunk]): A list of `Chunk` instances to be stored. 
-
-        Returns:
-            a list of tuples: (doc_id, chunk_id)
-        """
         return self._database.add_chunks(chunks=chunks)
 
-    # implements the abc method to handle LangChain add
+    @override
     def add_texts(
         self,
         texts: List[str],
         metadatas: Optional[List[Metadata]] = None,
         doc_id: Optional[str] = None,
     ) -> List[Tuple[str, int]]:
-        """Adds text chunks to the vector store.
-
-        Embeds and stores a list of text chunks and optional metadata into the vector
-        store.
-
-        Args:
-            texts: The list of text chunks to be embedded
-            metadatas: An optional list of Metadata to be stored.
-                If provided, these are set 1 to 1 with the texts list.
-            doc_id: The document id associated with the texts.
-                If not provided, it is generated.
-
-        Returns:
-            a list of tuples: (doc_id, chunk_id)
-        """
         chunks = self._build_chunks(texts=texts, metadatas=metadatas, doc_id=doc_id)
         return self._database.add_chunks(chunks=chunks)
 
-    # implements the abc method to handle LangChain and LlamaIndex delete
+    @override
     def delete_chunks(self, doc_ids: List[str]) -> bool:
-        """Deletes chunks from the vector store based on their document id.
-
-        Args:
-            doc_ids: A list of document identifiers specifying the chunks to be deleted.
-
-        Returns:
-            True if the all the deletes were successful.
-        """
         return self._database.delete_chunks(doc_ids=doc_ids)
 
-    # implements the abc method to handle LlamaIndex add
+    @override
     async def aadd_chunks(
-        self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
+        self, chunks: List[Chunk], concurrent_inserts: int = 100
     ) -> List[Tuple[str, int]]:
-        """Stores a list of embedded text chunks in the vector store.
-
-        Args:
-            chunks: A list of `Chunk` instances to be stored.
-            concurrent_inserts: How many concurrent inserts to make to the database.
-                Defaults to 100.
-
-        Returns:
-            a list of tuples: (doc_id, chunk_id)
-        """
         return await self._database.aadd_chunks(
             chunks=chunks, concurrent_inserts=concurrent_inserts
         )
 
-    # implements the abc method to handle LangChain add
+    @override
     async def aadd_texts(
         self,
         texts: List[str],
         metadatas: Optional[List[Metadata]] = None,
         doc_id: Optional[str] = None,
-        concurrent_inserts: Optional[int] = 100,
+        concurrent_inserts: int = 100,
     ) -> List[Tuple[str, int]]:
-        """Adds text chunks to the vector store.
-
-        Embeds and stores a list of text chunks and optional metadata into the vector
-        store.
-
-        Args:
-            texts (List[str]): The list of text chunks to be embedded
-            metadatas: An optional list of Metadata to be stored.
-                If provided, these are set 1 to 1 with the texts list.
-            doc_id: The document id associated with the texts.
-                If not provided, it is generated.
-            concurrent_inserts: How many concurrent inserts to make to the database.
-                Defaults to 100.
-
-        Returns:
-            a list of tuples: (doc_id, chunk_id)
-        """
         chunks = self._build_chunks(texts=texts, metadatas=metadatas, doc_id=doc_id)
         return await self._database.aadd_chunks(
             chunks=chunks, concurrent_inserts=concurrent_inserts
         )
 
-    # implements the abc method to handle LangChain and LlamaIndex delete
+    @override
     async def adelete_chunks(
-        self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
+        self, doc_ids: List[str], concurrent_deletes: int = 100
     ) -> bool:
-        """Deletes chunks from the vector store based on their document id.
-
-        Args:
-            doc_ids: A list of document identifiers specifying the chunks to be deleted.
-            concurrent_deletes: How many concurrent deletes to make to the database.
-                Defaults to 100.
-
-        Returns:
-            True if the all the deletes were successful.
- """ return await self._database.adelete_chunks( doc_ids=doc_ids, concurrent_deletes=concurrent_deletes ) + @override def as_retriever(self) -> BaseRetriever: - """Gets a retriever using the vector store.""" - self._validate_embedding_model() + embedding_model = self._validate_embedding_model() return ColbertRetriever( - database=self._database, embedding_model=self._embedding_model + database=self._database, embedding_model=embedding_model ) diff --git a/libs/colbert/ragstack_colbert/objects.py b/libs/colbert/ragstack_colbert/objects.py index 150dc3e1f..baed6a1d2 100644 --- a/libs/colbert/ragstack_colbert/objects.py +++ b/libs/colbert/ragstack_colbert/objects.py @@ -57,5 +57,5 @@ def __lt__(self, other: object) -> bool: # Allow objects to be hashable - only necessary if you need to use them in sets or # as dict keys. - def __hash__(self): + def __hash__(self) -> int: return hash((self.doc_id, self.chunk_id)) diff --git a/libs/colbert/ragstack_colbert/text_encoder.py b/libs/colbert/ragstack_colbert/text_encoder.py index 5445c0fb5..7b520bad0 100644 --- a/libs/colbert/ragstack_colbert/text_encoder.py +++ b/libs/colbert/ragstack_colbert/text_encoder.py @@ -9,7 +9,7 @@ """ import logging -from typing import List, Optional +from typing import List, Optional, cast import torch from colbert.infra import ColBERTConfig @@ -107,7 +107,7 @@ def encode_chunks(self, chunks: List[Chunk], batch_size: int = 640) -> List[Chun return embedded_chunks def encode_query( - self, text: str, query_maxlen: int, full_length_search: Optional[bool] = False + self, text: str, query_maxlen: int, full_length_search: bool = False ) -> Embedding: """Encodes a query into an embedding.""" if query_maxlen < 0: @@ -127,4 +127,4 @@ def encode_query( self._checkpoint.query_tokenizer.query_maxlen = prev_query_maxlen - return query_embedding.tolist()[0] + return cast(Embedding, query_embedding.tolist()[0]) diff --git a/libs/colbert/tests/__init__.py b/libs/colbert/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/libs/colbert/tests/integration_tests/conftest.py b/libs/colbert/tests/integration_tests/conftest.py index f0c8c68e0..f2d61325d 100644 --- a/libs/colbert/tests/integration_tests/conftest.py +++ b/libs/colbert/tests/integration_tests/conftest.py @@ -1,4 +1,5 @@ import pytest +from _pytest.fixtures import FixtureRequest from cassandra.cluster import Session from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore @@ -17,7 +18,7 @@ def astra_db() -> AstraDBTestStore: @pytest.fixture() -def session(request) -> Session: +def session(request: FixtureRequest) -> Session: test_store = request.getfixturevalue(request.param) session = test_store.create_cassandra_session() session.default_timeout = 180 diff --git a/libs/colbert/tests/integration_tests/test_database.py b/libs/colbert/tests/integration_tests/test_database.py index ccaf99067..50a4150ae 100644 --- a/libs/colbert/tests/integration_tests/test_database.py +++ b/libs/colbert/tests/integration_tests/test_database.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"]) -def test_database_sync(session: Session): +def test_database_sync(session: Session) -> None: doc_id = "earth_doc_id" chunk_0 = Chunk( @@ -43,15 +43,17 @@ def test_database_sync(session: Session): @pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"]) -async def test_database_async(session: Session): +async def test_database_async(session: Session) -> None: doc_id = "earth_doc_id" + 
+    climate_change_embedding = TestData.climate_change_embedding()
+
     chunk_0 = Chunk(
         doc_id=doc_id,
         chunk_id=0,
         text=TestData.climate_change_text(),
         metadata={"name": "climate_change", "id": 23},
-        embedding=TestData.climate_change_embedding(),
+        embedding=climate_change_embedding,
     )
 
     chunk_1 = Chunk(
@@ -73,7 +75,9 @@ async def test_database_async(session: Session):
     assert results[0] == (doc_id, 0)
     assert results[1] == (doc_id, 1)
 
-    chunks = await database.search_relevant_chunks(vector=chunk_0.embedding[5], n=2)
+    chunks = await database.search_relevant_chunks(
+        vector=climate_change_embedding[5], n=2
+    )
     assert len(chunks) == 1
     assert chunks[0].doc_id == doc_id
     assert chunks[0].chunk_id == 0
diff --git a/libs/colbert/tests/integration_tests/test_embedding_retrieval.py b/libs/colbert/tests/integration_tests/test_embedding_retrieval.py
index fc0199b92..e7a47ef7e 100644
--- a/libs/colbert/tests/integration_tests/test_embedding_retrieval.py
+++ b/libs/colbert/tests/integration_tests/test_embedding_retrieval.py
@@ -1,4 +1,5 @@
 import logging
+from typing import List
 
 import pytest
 from cassandra.cluster import Session
@@ -11,7 +12,7 @@
 
 @pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"])
-def test_embedding_cassandra_retriever(session: Session):
+def test_embedding_cassandra_retriever(session: Session) -> None:
     narrative = TestData.marine_animals_text()
 
     # Define the desired chunk size and overlap size
     chunk_size = 450
     overlap_size = 50
 
     # Function to generate chunks with the specified size and overlap
-    def chunk_texts(text, chunk_size, overlap_size):
+    def chunk_texts(text: str, chunk_size: int, overlap_size: int) -> List[str]:
         texts = []
         start = 0
         end = chunk_size
diff --git a/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py b/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py
index 2ec4f4916..027037574 100644
--- a/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py
+++ b/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py
@@ -42,7 +42,7 @@
 
 # a uility function to evaluate similarity of two embeddings at per token level
-def are_they_similar(embedded_chunks: List[Embedding], tensors: List[Tensor]):
+def are_they_similar(embedded_chunks: List[Embedding], tensors: List[Tensor]) -> None:
     n = 0
     pdist = torch.nn.PairwiseDistance(p=2)
     for embedding in embedded_chunks:
@@ -62,7 +62,7 @@ def are_they_similar(embedded_chunks: List[Embedding], tensors: List[Tensor]):
     assert n == len(tensors)
 
 
-def test_embeddings_with_baseline():
+def test_embeddings_with_baseline() -> None:
     colbert = ColbertEmbeddingModel(
         doc_maxlen=220,
         nbits=2,
@@ -117,7 +117,7 @@ def test_embeddings_with_baseline():
 
     are_they_similar(embedded_chunks2, embedded_tensors)
 
 
-def test_colbert_embedding_against_vanilla_impl():
+def test_colbert_embedding_against_vanilla_impl() -> None:
     # this is a vanilla ColBERT embedding in a list of per token embeddings
     # based on the just Stanford ColBERT library
     cf = ColBERTConfig(checkpoint="colbert-ir/colbertv2.0")
@@ -134,7 +134,7 @@ def test_colbert_embedding_against_vanilla_impl():
 
     are_they_similar(embedded_chunks, embeddings_flat)
 
 
-def model_embedding(model: str):
+def model_embedding(model: str) -> None:
     logging.info("test model compatibility %s", model)
     colbert_svc = ColbertEmbeddingModel(
         checkpoint=model,
@@ -161,7 +161,7 @@ def model_embedding(model: str):
 
     assert len(embedding) == query_maxlen
 
 
-def test_compatible_models():
+def test_compatible_models() -> None:
     # ColBERT models and Google BERT models on HF
     # test representive models's compatibility with this repo's ColBERT embedding
     # evaluation is not within this test scope
     models: List[str] = [
@@ -177,4 +177,5 @@ def test_compatible_models():
         # "google-bert/bert-base-cased", # already tested uncased
     ]
 
-    [model_embedding(model) for model in models]
+    for model in models:
+        model_embedding(model)
diff --git a/libs/colbert/tests/unit_tests/test_colbert_embeddings.py b/libs/colbert/tests/unit_tests/test_colbert_embeddings.py
index 6c1015fe2..91152f3ba 100644
--- a/libs/colbert/tests/unit_tests/test_colbert_embeddings.py
+++ b/libs/colbert/tests/unit_tests/test_colbert_embeddings.py
@@ -3,7 +3,7 @@
 from ragstack_colbert.constant import DEFAULT_COLBERT_DIM, DEFAULT_COLBERT_MODEL
 
 
-def test_colbert_token_embeddings():
+def test_colbert_token_embeddings() -> None:
     colbert = ColbertEmbeddingModel()
     texts = ["test1", "test2"]
@@ -14,7 +14,7 @@ def test_colbert_token_embeddings():
     assert len(embeddings[0][0]) == DEFAULT_COLBERT_DIM
 
 
-def test_colbert_token_embeddings_with_params():
+def test_colbert_token_embeddings_with_params() -> None:
     colbert = ColbertEmbeddingModel(
         doc_maxlen=220,
         nbits=2,
@@ -32,7 +32,7 @@ def test_colbert_token_embeddings_with_params():
     assert len(embeddings[0][0]) == DEFAULT_COLBERT_DIM
 
 
-def test_colbert_query_embeddings():
+def test_colbert_query_embeddings() -> None:
     colbert = ColbertEmbeddingModel()
 
     embedding = colbert.embed_query("who is the president of the united states?")
diff --git a/libs/colbert/tests/unit_tests/test_colbert_retriever.py b/libs/colbert/tests/unit_tests/test_colbert_retriever.py
index 6675c2f4e..277ba0409 100644
--- a/libs/colbert/tests/unit_tests/test_colbert_retriever.py
+++ b/libs/colbert/tests/unit_tests/test_colbert_retriever.py
@@ -3,7 +3,7 @@
 from ragstack_colbert.text_encoder import calculate_query_maxlen
 
 
-def test_max_similarity_torch():
+def test_max_similarity_torch() -> None:
     # Example query vector and embedding list
     query_vector = torch.tensor([1, 2, 3], dtype=torch.float32)
     embedding_list = [
@@ -20,7 +20,9 @@ def test_max_similarity_torch():
     )  # Should be the highest
 
     # Call the function under test
-    max_sim = max_similarity_torch(query_vector, embedding_list)
+    max_sim = max_similarity_torch(
+        query_vector.tolist(), [embedding.tolist() for embedding in embedding_list]
+    )
 
     # Check if the returned max similarity matches the expected value
     assert (
         max_sim == expected_max_similarity.item()
     ), "The max similarity does not match the expected value."
 
-def test_query_maxlen_calculation():
+def test_query_maxlen_calculation() -> None:
     tokens = [["word1"], ["word2", "word3"]]
 
     assert calculate_query_maxlen(tokens) == 5  # noqa: PLR2004
diff --git a/libs/colbert/tox.ini b/libs/colbert/tox.ini
index d05fe6e3b..b9281072e 100644
--- a/libs/colbert/tox.ini
+++ b/libs/colbert/tox.ini
@@ -1,25 +1,30 @@
 [tox]
 min_version = 4.0
-envlist = py311
+envlist = type, unit-tests, integration-tests
+
+[testenv]
+description = install dependencies
+skip_install = true
+allowlist_externals = poetry
+commands_pre =
+    poetry env use system
+    poetry install
 
 [testenv:unit-tests]
 description = run unit tests
-deps =
-    poetry
 commands =
-    poetry install
-    poetry build
     poetry run pytest --disable-warnings {toxinidir}/tests/unit_tests
 
 [testenv:integration-tests]
 description = run integration tests
-deps =
-    poetry
 pass_env =
     ASTRA_DB_TOKEN
     ASTRA_DB_ID
     ASTRA_DB_ENV
 commands =
-    poetry install
-    poetry -V
-    poetry run pytest --disable-warnings {toxinidir}/tests/integration_tests
\ No newline at end of file
+    poetry run pytest --disable-warnings {toxinidir}/tests/integration_tests
+
+[testenv:type]
+description = run type checking
+commands =
+    poetry run mypy {toxinidir}
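
The pattern this diff applies throughout cassandra_database.py and colbert_retriever.py (fan work out under an `asyncio.Semaphore`, gather with `return_exceptions=True`, then narrow each result with `isinstance(result, BaseException)` before unpacking) is what lets those loops type-check under strict mypy, since `gather` can now yield either a result tuple or an exception. A minimal, self-contained sketch of the pattern, using toy names rather than the library's API:

```python
import asyncio
import logging
from typing import List, Optional, Tuple


async def limited_put(
    sem: asyncio.Semaphore, item: str
) -> Tuple[str, Optional[Exception]]:
    # The semaphore caps how many coroutines enter the critical section at
    # once, mirroring the concurrent_inserts / concurrent_deletes knobs.
    async with sem:
        try:
            await asyncio.sleep(0)  # stand-in for a database write
            return item, None
        except Exception as e:  # noqa: BLE001
            return item, e


async def insert_all(items: List[str], concurrency: int = 100) -> List[str]:
    sem = asyncio.Semaphore(concurrency)
    results = await asyncio.gather(
        *(limited_put(sem, item) for item in items), return_exceptions=True
    )
    succeeded: List[str] = []
    for result in results:
        # With return_exceptions=True each element is either a result tuple
        # or a BaseException; narrowing first lets mypy unpack the tuple.
        if isinstance(result, BaseException):
            logging.error("insert failed", exc_info=result)
        else:
            item, exp = result
            if exp is None:
                succeeded.append(item)
    return succeeded


print(asyncio.run(insert_all(["a", "b", "c"])))
```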
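
The diff also leans on `typing_extensions.Self` for the alternate constructors (`from_astra()`, `from_session()`) and on `@override` in place of repeated docstrings, which is why so many docstring blocks disappear above. A hedged sketch, with toy classes rather than the library's, of what each buys:

```python
from typing import List

from typing_extensions import Self, override


class BaseStore:
    name: str

    @classmethod
    def from_name(cls, name: str) -> Self:
        # Returning Self means SubStore.from_name() is inferred as SubStore,
        # not BaseStore, with no per-subclass overloads.
        instance = cls()
        instance.name = name
        return instance

    def add(self, items: List[str]) -> int:
        raise NotImplementedError


class SubStore(BaseStore):
    @override  # mypy flags this if the base class has no matching method
    def add(self, items: List[str]) -> int:
        return len(items)


store: SubStore = SubStore.from_name("demo")
print(store.add(["x", "y"]))
```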
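
One easy-to-miss change: `sorted(chunk_scores, key=chunk_scores.get, ...)` is rejected under strict typing because `dict.get` returns an optional value; wrapping it in a lambda with a default restores a non-optional sort key. For example:

```python
from typing import Dict

chunk_scores: Dict[str, float] = {"a": 0.9, "b": 0.4, "c": 0.7}

# dict.get alone is Callable[[str], Optional[float]], which mypy rejects as
# a sort key; supplying a default makes the key function return plain float.
top = sorted(chunk_scores, key=lambda c: chunk_scores.get(c, 0), reverse=True)
print(top)  # ['a', 'c', 'b']
```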
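
Finally, where an untyped third-party call returns `Any` (for example `Tensor.tolist()` in text_encoder.py), the diff pins the type with `typing.cast` rather than loosening the signature; with `warn_return_any = true`, returning the `Any` directly would be an error. An illustrative sketch, assuming `torch` is installed as the library already requires:

```python
from typing import List, cast

import torch

Vector = List[float]
Embedding = List[Vector]


def encode(rows: int, dim: int) -> Embedding:
    tensor = torch.rand(rows, dim)
    # Tensor.tolist() is typed as returning Any; cast() records the shape we
    # know the value has, so the return satisfies warn_return_any.
    return cast(Embedding, tensor.tolist())


print(len(encode(2, 4)))
```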