
Add type checking of ragstack-colbert #603

Merged 1 commit on Jul 23, 2024
3 changes: 3 additions & 0 deletions .github/workflows/ci-unit-tests.yml
@@ -86,6 +86,9 @@ jobs:
       - name: "Setup: Python 3.11"
         uses: ./.github/actions/setup-python

+      - name: "Type check (colbert)"
+        run: tox -e type -c libs/colbert && rm -rf libs/colbert/.tox
+
       - name: "Type check (knowledge-graph)"
         run: tox -e type -c libs/knowledge-graph && rm -rf libs/knowledge-graph/.tox
19 changes: 19 additions & 0 deletions libs/colbert/pyproject.toml
@@ -21,10 +21,29 @@ pydantic = "^2.7.1"
 # Remove when we upgrade to pytorch 2.4
 setuptools = { version = ">=70", python = ">=3.12" }

+[tool.poetry.group.dev.dependencies]
+mypy = "^1.11.0"
+
 [tool.poetry.group.test.dependencies]
 ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
 pytest-asyncio = "^0.23.6"

 [tool.pytest.ini_options]
 asyncio_mode = "auto"

+[tool.mypy]
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_calls = true
+disallow_untyped_decorators = true
+disallow_untyped_defs = true
+follow_imports = "normal"
+ignore_missing_imports = true
+no_implicit_reexport = true
+show_error_codes = true
+show_error_context = true
+strict_equality = true
+strict_optional = true
+warn_redundant_casts = true
+warn_return_any = true
+warn_unused_ignores = true
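
For context on how strict this configuration is: with disallow_untyped_defs and disallow_incomplete_defs enabled, mypy rejects any function that is not fully annotated. A minimal illustration (the function below is invented for this example, not taken from the PR):

from typing import List, Optional

# Rejected under disallow_untyped_defs: no annotations at all.
# def mean(values):
#     return sum(values) / len(values)

# Accepted: parameters and return type are fully annotated, and because
# strict_optional is enabled the None case is handled explicitly.
def mean(values: List[float], scale: Optional[float] = None) -> float:
    factor = 1.0 if scale is None else scale
    return factor * sum(values) / len(values)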
16 changes: 8 additions & 8 deletions libs/colbert/ragstack_colbert/base_database.py
@@ -6,7 +6,7 @@
 """

 from abc import ABC, abstractmethod
-from typing import List, Optional, Tuple
+from typing import List, Tuple

 from .objects import Chunk, Vector

@@ -48,13 +48,13 @@ def delete_chunks(self, doc_ids: List[str]) -> bool:

     @abstractmethod
     async def aadd_chunks(
-        self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
+        self, chunks: List[Chunk], concurrent_inserts: int = 100
     ) -> List[Tuple[str, int]]:
         """Stores a list of embedded text chunks in the vector store.

         Args:
-            chunks (List[Chunk]): A list of `Chunk` instances to be stored.
-            concurrent_inserts (Optional[int]): How many concurrent inserts to make to
+            chunks: A list of `Chunk` instances to be stored.
+            concurrent_inserts: How many concurrent inserts to make to
                 the database. Defaults to 100.

         Returns:
@@ -63,14 +63,14 @@ async def adelete_chunks(

     @abstractmethod
     async def adelete_chunks(
-        self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
+        self, doc_ids: List[str], concurrent_deletes: int = 100
     ) -> bool:
         """Deletes chunks from the vector store based on their document id.

         Args:
-            doc_ids (List[str]): A list of document identifiers specifying the chunks
+            doc_ids: A list of document identifiers specifying the chunks
                 to be deleted.
-            concurrent_deletes (Optional[int]): How many concurrent deletes to make
+            concurrent_deletes: How many concurrent deletes to make
                 to the database. Defaults to 100.

         Returns:
@@ -96,7 +96,7 @@ async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk:

     @abstractmethod
     async def get_chunk_data(
-        self, doc_id: str, chunk_id: int, include_embedding: Optional[bool]
+        self, doc_id: str, chunk_id: int, include_embedding: bool = False
     ) -> Chunk:
         """Retrieve the text and metadata for a chunk.

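
The signature changes above hinge on what Optional actually means in typing: Optional[int] is Union[int, None], i.e. "None is an acceptable value", not "this parameter is optional". Since callers are never expected to pass None for concurrent_inserts or concurrent_deletes, int = 100 is the accurate annotation. A hypothetical sketch of the difference (function names invented here):

from typing import Optional

# Old annotation: claims None is valid, so strict mypy forces a None branch.
def add_old(concurrent_inserts: Optional[int] = 100) -> int:
    if concurrent_inserts is None:  # branch that never fires in practice
        concurrent_inserts = 100
    return concurrent_inserts

# New annotation: a plain default; passing None is now a type error.
def add_new(concurrent_inserts: int = 100) -> int:
    return concurrent_inserts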
14 changes: 7 additions & 7 deletions libs/colbert/ragstack_colbert/base_embedding_model.py
@@ -35,21 +35,21 @@ def embed_texts(self, texts: List[str]) -> List[Embedding]:
     def embed_query(
         self,
         query: str,
-        full_length_search: Optional[bool] = False,
-        query_maxlen: int = -1,
+        full_length_search: bool = False,
+        query_maxlen: Optional[int] = None,
     ) -> Embedding:
         """Embeds a single query text into its vector representation.

         If the query has fewer than query_maxlen tokens it will be padded with BERT
         special [MASK] tokens.

         Args:
-            query (str): The query text to encode.
-            full_length_search (Optional[bool]): Indicates whether to encode the
+            query: The query text to encode.
+            full_length_search: Indicates whether to encode the
                 query for a full-length search. Defaults to False.
-            query_maxlen (int): The fixed length for the query token embedding.
-                If -1, uses a dynamically calculated value.
+            query_maxlen: The fixed length for the query token embedding.
+                If None, uses a dynamically calculated value.

         Returns:
-            Embedding: A vector embedding representation of the query text
+            A vector embedding representation of the query text
         """
52 changes: 26 additions & 26 deletions libs/colbert/ragstack_colbert/base_retriever.py
@@ -25,7 +25,7 @@ def embedding_search(
         self,
         query_embedding: Embedding,
         k: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query embedding.
@@ -34,16 +34,16 @@ def embedding_search(
         store, ranked by relevance or other metrics.

         Args:
-            query_embedding (Embedding): The query embedding to search for relevant
+            query_embedding: The query embedding to search for relevant
                 text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            k: The number of top results to retrieve.
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.

         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
@@ -54,7 +54,7 @@ async def aembedding_search(
         self,
         query_embedding: Embedding,
         k: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query embedding.
@@ -63,16 +63,16 @@ async def aembedding_search(
         store, ranked by relevance or other metrics.

         Args:
-            query_embedding (Embedding): The query embedding to search for relevant
+            query_embedding: The query embedding to search for relevant
                 text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            k: The number of top results to retrieve.
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.

         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
@@ -84,7 +84,7 @@ def text_search(
         query_text: str,
         k: Optional[int] = None,
         query_maxlen: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query text.
@@ -93,17 +93,17 @@ def text_search(
         store, ranked by relevance or other metrics.

         Args:
-            query_text (str): The query text to search for relevant text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            query_maxlen (Optional[int]): The maximum length of the query to consider.
+            query_text: The query text to search for relevant text chunks.
+            k: The number of top results to retrieve.
+            query_maxlen: The maximum length of the query to consider.
                 If None, the maxlen will be dynamically generated.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.

         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
@@ -115,7 +115,7 @@ async def atext_search(
         query_text: str,
         k: Optional[int] = None,
         query_maxlen: Optional[int] = None,
-        include_embedding: Optional[bool] = False,
+        include_embedding: bool = False,
         **kwargs: Any,
     ) -> List[Tuple[Chunk, float]]:
         """Search for relevant text chunks based on a query text.
@@ -124,17 +124,17 @@ async def atext_search(
         store, ranked by relevance or other metrics.

         Args:
-            query_text (str): The query text to search for relevant text chunks.
-            k (Optional[int]): The number of top results to retrieve.
-            query_maxlen (Optional[int]): The maximum length of the query to consider.
+            query_text: The query text to search for relevant text chunks.
+            k: The number of top results to retrieve.
+            query_maxlen: The maximum length of the query to consider.
                 If None, the maxlen will be dynamically generated.
-            include_embedding (Optional[bool]): Optional (default False) flag to
+            include_embedding: Optional (default False) flag to
                 include the embedding vectors in the returned chunks
-            **kwargs (Any): Additional parameters that implementations might require
+            **kwargs: Additional parameters that implementations might require
                 for customized retrieval operations.

         Returns:
-            List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
+            A list of retrieved Chunk, float Tuples,
                 each representing a text chunk that is relevant to the query,
                 along with its similarity score.
         """
22 changes: 11 additions & 11 deletions libs/colbert/ragstack_colbert/base_vector_store.py
@@ -87,13 +87,13 @@ def delete_chunks(self, doc_ids: List[str]) -> bool:
     # handles LlamaIndex add
     @abstractmethod
     async def aadd_chunks(
-        self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
+        self, chunks: List[Chunk], concurrent_inserts: int = 100
     ) -> List[Tuple[str, int]]:
         """Stores a list of embedded text chunks in the vector store.

         Args:
-            chunks (List[Chunk]): A list of `Chunk` instances to be stored.
-            concurrent_inserts (Optional[int]): How many concurrent inserts to make to
+            chunks: A list of `Chunk` instances to be stored.
+            concurrent_inserts: How many concurrent inserts to make to
                 the database. Defaults to 100.

         Returns:
@@ -107,20 +107,20 @@ async def aadd_texts(
         texts: List[str],
         metadatas: Optional[List[Metadata]],
         doc_id: Optional[str] = None,
-        concurrent_inserts: Optional[int] = 100,
+        concurrent_inserts: int = 100,
     ) -> List[Tuple[str, int]]:
         """Adds text chunks to the vector store.

         Embeds and stores a list of text chunks and optional metadata into the vector
         store.

         Args:
-            texts (List[str]): The list of text chunks to be embedded
-            metadatas (Optional[List[Metadata]])): An optional list of Metadata to be
+            texts: The list of text chunks to be embedded
+            metadatas: An optional list of Metadata to be
                 stored. If provided, these are set 1 to 1 with the texts list.
-            doc_id (Optional[str]): The document id associated with the texts.
+            doc_id: The document id associated with the texts.
                 If not provided, it is generated.
-            concurrent_inserts (Optional[int]): How many concurrent inserts to make to
+            concurrent_inserts: How many concurrent inserts to make to
                 the database. Defaults to 100.

         Returns:
@@ -130,14 +130,14 @@ async def adelete_chunks(
     # handles LangChain and LlamaIndex delete
     @abstractmethod
     async def adelete_chunks(
-        self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
+        self, doc_ids: List[str], concurrent_deletes: int = 100
     ) -> bool:
         """Deletes chunks from the vector store based on their document id.

         Args:
-            doc_ids (List[str]): A list of document identifiers specifying the chunks
+            doc_ids: A list of document identifiers specifying the chunks
                 to be deleted.
-            concurrent_deletes (Optional[int]): How many concurrent deletes to make to
+            concurrent_deletes: How many concurrent deletes to make to
                 the database. Defaults to 100.

         Returns:
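
The concurrent_inserts and concurrent_deletes parameters cap how many database operations run at once. One plausible way to honor such a cap with asyncio, shown purely as an illustration of the pattern (the function below is not the library's actual implementation):

import asyncio
from typing import Awaitable, Callable, List, Tuple

async def bounded_gather(
    jobs: List[Callable[[], Awaitable[Tuple[str, int]]]],
    concurrency: int = 100,
) -> List[Tuple[str, int]]:
    # The semaphore keeps at most `concurrency` operations in flight.
    semaphore = asyncio.Semaphore(concurrency)

    async def run_one(job: Callable[[], Awaitable[Tuple[str, int]]]) -> Tuple[str, int]:
        async with semaphore:
            return await job()

    return list(await asyncio.gather(*(run_one(job) for job in jobs)))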