Skip to content

Commit bc150a9

Browse files
authored
Add type checking of ragstack-colbert (#603)
1 parent 8a137ef commit bc150a9

20 files changed

+279
-442
lines changed

.github/workflows/ci-unit-tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ jobs:
8686
- name: "Setup: Python 3.11"
8787
uses: ./.github/actions/setup-python
8888

89+
- name: "Type check (colbert)"
90+
run: tox -e type -c libs/colbert && rm -rf libs/colbert/.tox
91+
8992
- name: "Type check (knowledge-graph)"
9093
run: tox -e type -c libs/knowledge-graph && rm -rf libs/knowledge-graph/.tox
9194

libs/colbert/pyproject.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,29 @@ pydantic = "^2.7.1"
2121
# Remove when we upgrade to pytorch 2.4
2222
setuptools = { version = ">=70", python = ">=3.12" }
2323

24+
[tool.poetry.group.dev.dependencies]
25+
mypy = "^1.11.0"
2426

2527
[tool.poetry.group.test.dependencies]
2628
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
2729
pytest-asyncio = "^0.23.6"
2830

2931
[tool.pytest.ini_options]
3032
asyncio_mode = "auto"
33+
34+
[tool.mypy]
35+
disallow_any_generics = true
36+
disallow_incomplete_defs = true
37+
disallow_untyped_calls = true
38+
disallow_untyped_decorators = true
39+
disallow_untyped_defs = true
40+
follow_imports = "normal"
41+
ignore_missing_imports = true
42+
no_implicit_reexport = true
43+
show_error_codes = true
44+
show_error_context = true
45+
strict_equality = true
46+
strict_optional = true
47+
warn_redundant_casts = true
48+
warn_return_any = true
49+
warn_unused_ignores = true

libs/colbert/ragstack_colbert/base_database.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from abc import ABC, abstractmethod
9-
from typing import List, Optional, Tuple
9+
from typing import List, Tuple
1010

1111
from .objects import Chunk, Vector
1212

@@ -48,13 +48,13 @@ def delete_chunks(self, doc_ids: List[str]) -> bool:
4848

4949
@abstractmethod
5050
async def aadd_chunks(
51-
self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
51+
self, chunks: List[Chunk], concurrent_inserts: int = 100
5252
) -> List[Tuple[str, int]]:
5353
"""Stores a list of embedded text chunks in the vector store.
5454
5555
Args:
56-
chunks (List[Chunk]): A list of `Chunk` instances to be stored.
57-
concurrent_inserts (Optional[int]): How many concurrent inserts to make to
56+
chunks: A list of `Chunk` instances to be stored.
57+
concurrent_inserts: How many concurrent inserts to make to
5858
the database. Defaults to 100.
5959
6060
Returns:
@@ -63,14 +63,14 @@ async def aadd_chunks(
6363

6464
@abstractmethod
6565
async def adelete_chunks(
66-
self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
66+
self, doc_ids: List[str], concurrent_deletes: int = 100
6767
) -> bool:
6868
"""Deletes chunks from the vector store based on their document id.
6969
7070
Args:
71-
doc_ids (List[str]): A list of document identifiers specifying the chunks
71+
doc_ids: A list of document identifiers specifying the chunks
7272
to be deleted.
73-
concurrent_deletes (Optional[int]): How many concurrent deletes to make
73+
concurrent_deletes: How many concurrent deletes to make
7474
to the database. Defaults to 100.
7575
7676
Returns:
@@ -96,7 +96,7 @@ async def get_chunk_embedding(self, doc_id: str, chunk_id: int) -> Chunk:
9696

9797
@abstractmethod
9898
async def get_chunk_data(
99-
self, doc_id: str, chunk_id: int, include_embedding: Optional[bool]
99+
self, doc_id: str, chunk_id: int, include_embedding: bool = False
100100
) -> Chunk:
101101
"""Retrieve the text and metadata for a chunk.
102102

libs/colbert/ragstack_colbert/base_embedding_model.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@ def embed_texts(self, texts: List[str]) -> List[Embedding]:
3535
def embed_query(
3636
self,
3737
query: str,
38-
full_length_search: Optional[bool] = False,
39-
query_maxlen: int = -1,
38+
full_length_search: bool = False,
39+
query_maxlen: Optional[int] = None,
4040
) -> Embedding:
4141
"""Embeds a single query text into its vector representation.
4242
4343
If the query has fewer than query_maxlen tokens it will be padded with BERT
4444
special [MASK] tokens.
4545
4646
Args:
47-
query (str): The query text to encode.
48-
full_length_search (Optional[bool]): Indicates whether to encode the
47+
query: The query text to encode.
48+
full_length_search: Indicates whether to encode the
4949
query for a full-length search. Defaults to False.
50-
query_maxlen (int): The fixed length for the query token embedding.
51-
If -1, uses a dynamically calculated value.
50+
query_maxlen: The fixed length for the query token embedding.
51+
If None, uses a dynamically calculated value.
5252
5353
Returns:
54-
Embedding: A vector embedding representation of the query text
54+
A vector embedding representation of the query text
5555
"""

libs/colbert/ragstack_colbert/base_retriever.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def embedding_search(
2525
self,
2626
query_embedding: Embedding,
2727
k: Optional[int] = None,
28-
include_embedding: Optional[bool] = False,
28+
include_embedding: bool = False,
2929
**kwargs: Any,
3030
) -> List[Tuple[Chunk, float]]:
3131
"""Search for relevant text chunks based on a query embedding.
@@ -34,16 +34,16 @@ def embedding_search(
3434
store, ranked by relevance or other metrics.
3535
3636
Args:
37-
query_embedding (Embedding): The query embedding to search for relevant
37+
query_embedding: The query embedding to search for relevant
3838
text chunks.
39-
k (Optional[int]): The number of top results to retrieve.
40-
include_embedding (Optional[bool]): Optional (default False) flag to
39+
k: The number of top results to retrieve.
40+
include_embedding: Optional (default False) flag to
4141
include the embedding vectors in the returned chunks
42-
**kwargs (Any): Additional parameters that implementations might require
42+
**kwargs: Additional parameters that implementations might require
4343
for customized retrieval operations.
4444
4545
Returns:
46-
List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
46+
A list of retrieved Chunk, float Tuples,
4747
each representing a text chunk that is relevant to the query,
4848
along with its similarity score.
4949
"""
@@ -54,7 +54,7 @@ async def aembedding_search(
5454
self,
5555
query_embedding: Embedding,
5656
k: Optional[int] = None,
57-
include_embedding: Optional[bool] = False,
57+
include_embedding: bool = False,
5858
**kwargs: Any,
5959
) -> List[Tuple[Chunk, float]]:
6060
"""Search for relevant text chunks based on a query embedding.
@@ -63,16 +63,16 @@ async def aembedding_search(
6363
store, ranked by relevance or other metrics.
6464
6565
Args:
66-
query_embedding (Embedding): The query embedding to search for relevant
66+
query_embedding: The query embedding to search for relevant
6767
text chunks.
68-
k (Optional[int]): The number of top results to retrieve.
69-
include_embedding (Optional[bool]): Optional (default False) flag to
68+
k: The number of top results to retrieve.
69+
include_embedding: Optional (default False) flag to
7070
include the embedding vectors in the returned chunks
71-
**kwargs (Any): Additional parameters that implementations might require
71+
**kwargs: Additional parameters that implementations might require
7272
for customized retrieval operations.
7373
7474
Returns:
75-
List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
75+
A list of retrieved Chunk, float Tuples,
7676
each representing a text chunk that is relevant to the query,
7777
along with its similarity score.
7878
"""
@@ -84,7 +84,7 @@ def text_search(
8484
query_text: str,
8585
k: Optional[int] = None,
8686
query_maxlen: Optional[int] = None,
87-
include_embedding: Optional[bool] = False,
87+
include_embedding: bool = False,
8888
**kwargs: Any,
8989
) -> List[Tuple[Chunk, float]]:
9090
"""Search for relevant text chunks based on a query text.
@@ -93,17 +93,17 @@ def text_search(
9393
store, ranked by relevance or other metrics.
9494
9595
Args:
96-
query_text (str): The query text to search for relevant text chunks.
97-
k (Optional[int]): The number of top results to retrieve.
98-
query_maxlen (Optional[int]): The maximum length of the query to consider.
96+
query_text: The query text to search for relevant text chunks.
97+
k: The number of top results to retrieve.
98+
query_maxlen: The maximum length of the query to consider.
9999
If None, the maxlen will be dynamically generated.
100-
include_embedding (Optional[bool]): Optional (default False) flag to
100+
include_embedding: Optional (default False) flag to
101101
include the embedding vectors in the returned chunks
102-
**kwargs (Any): Additional parameters that implementations might require
102+
**kwargs: Additional parameters that implementations might require
103103
for customized retrieval operations.
104104
105105
Returns:
106-
List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
106+
A list of retrieved Chunk, float Tuples,
107107
each representing a text chunk that is relevant to the query,
108108
along with its similarity score.
109109
"""
@@ -115,7 +115,7 @@ async def atext_search(
115115
query_text: str,
116116
k: Optional[int] = None,
117117
query_maxlen: Optional[int] = None,
118-
include_embedding: Optional[bool] = False,
118+
include_embedding: bool = False,
119119
**kwargs: Any,
120120
) -> List[Tuple[Chunk, float]]:
121121
"""Search for relevant text chunks based on a query text.
@@ -124,17 +124,17 @@ async def atext_search(
124124
store, ranked by relevance or other metrics.
125125
126126
Args:
127-
query_text (str): The query text to search for relevant text chunks.
128-
k (Optional[int]): The number of top results to retrieve.
129-
query_maxlen (Optional[int]): The maximum length of the query to consider.
127+
query_text: The query text to search for relevant text chunks.
128+
k: The number of top results to retrieve.
129+
query_maxlen: The maximum length of the query to consider.
130130
If None, the maxlen will be dynamically generated.
131-
include_embedding (Optional[bool]): Optional (default False) flag to
131+
include_embedding: Optional (default False) flag to
132132
include the embedding vectors in the returned chunks
133-
**kwargs (Any): Additional parameters that implementations might require
133+
**kwargs: Additional parameters that implementations might require
134134
for customized retrieval operations.
135135
136136
Returns:
137-
List[Tuple[Chunk, float]]: A list of retrieved Chunk, float Tuples,
137+
A list of retrieved Chunk, float Tuples,
138138
each representing a text chunk that is relevant to the query,
139139
along with its similarity score.
140140
"""

libs/colbert/ragstack_colbert/base_vector_store.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,13 @@ def delete_chunks(self, doc_ids: List[str]) -> bool:
8787
# handles LlamaIndex add
8888
@abstractmethod
8989
async def aadd_chunks(
90-
self, chunks: List[Chunk], concurrent_inserts: Optional[int] = 100
90+
self, chunks: List[Chunk], concurrent_inserts: int = 100
9191
) -> List[Tuple[str, int]]:
9292
"""Stores a list of embedded text chunks in the vector store.
9393
9494
Args:
95-
chunks (List[Chunk]): A list of `Chunk` instances to be stored.
96-
concurrent_inserts (Optional[int]): How many concurrent inserts to make to
95+
chunks: A list of `Chunk` instances to be stored.
96+
concurrent_inserts: How many concurrent inserts to make to
9797
the database. Defaults to 100.
9898
9999
Returns:
@@ -107,20 +107,20 @@ async def aadd_texts(
107107
texts: List[str],
108108
metadatas: Optional[List[Metadata]],
109109
doc_id: Optional[str] = None,
110-
concurrent_inserts: Optional[int] = 100,
110+
concurrent_inserts: int = 100,
111111
) -> List[Tuple[str, int]]:
112112
"""Adds text chunks to the vector store.
113113
114114
Embeds and stores a list of text chunks and optional metadata into the vector
115115
store.
116116
117117
Args:
118-
texts (List[str]): The list of text chunks to be embedded
119-
metadatas (Optional[List[Metadata]])): An optional list of Metadata to be
118+
texts: The list of text chunks to be embedded
119+
metadatas: An optional list of Metadata to be
120120
stored. If provided, these are set 1 to 1 with the texts list.
121-
doc_id (Optional[str]): The document id associated with the texts.
121+
doc_id: The document id associated with the texts.
122122
If not provided, it is generated.
123-
concurrent_inserts (Optional[int]): How many concurrent inserts to make to
123+
concurrent_inserts: How many concurrent inserts to make to
124124
the database. Defaults to 100.
125125
126126
Returns:
@@ -130,14 +130,14 @@ async def aadd_texts(
130130
# handles LangChain and LlamaIndex delete
131131
@abstractmethod
132132
async def adelete_chunks(
133-
self, doc_ids: List[str], concurrent_deletes: Optional[int] = 100
133+
self, doc_ids: List[str], concurrent_deletes: int = 100
134134
) -> bool:
135135
"""Deletes chunks from the vector store based on their document id.
136136
137137
Args:
138-
doc_ids (List[str]): A list of document identifiers specifying the chunks
138+
doc_ids: A list of document identifiers specifying the chunks
139139
to be deleted.
140-
concurrent_deletes (Optional[int]): How many concurrent deletes to make to
140+
concurrent_deletes: How many concurrent deletes to make to
141141
the database. Defaults to 100.
142142
143143
Returns:

0 commit comments

Comments (0)