Skip to content

Commit c16cbd5

Browse files
authored
Add type checking of ragstack-llama-index (#614)
1 parent e78d791 commit c16cbd5

File tree

6 files changed

+68
-37
lines changed

6 files changed

+68
-37
lines changed

libs/llamaindex/pyproject.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,26 @@ google = ["llama-index-llms-gemini", "llama-index-multi-modal-llms-gemini", "lla
4444
azure = ["llama-index-llms-azure-openai", "llama-index-embeddings-azure-openai"]
4545
bedrock = ["llama-index-llms-bedrock", "llama-index-embeddings-bedrock"]
4646

47+
[tool.poetry.group.dev.dependencies]
48+
mypy = "^1.11.0"
49+
4750
[tool.poetry.group.test.dependencies]
4851
ragstack-ai-tests-utils = { path = "../tests-utils", develop = true }
4952
ragstack-ai-colbert = { path = "../colbert", develop = true }
53+
54+
[tool.mypy]
55+
disallow_any_generics = true
56+
disallow_incomplete_defs = true
57+
disallow_untyped_calls = true
58+
disallow_untyped_decorators = true
59+
disallow_untyped_defs = true
60+
follow_imports = "normal"
61+
ignore_missing_imports = true
62+
no_implicit_reexport = true
63+
show_error_codes = true
64+
show_error_context = true
65+
strict_equality = true
66+
strict_optional = true
67+
warn_redundant_casts = true
68+
warn_return_any = true
69+
warn_unused_ignores = true

libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, List, Optional, Tuple
1+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
22

33
from llama_index.core.callbacks.base import CallbackManager
44
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
@@ -27,7 +27,7 @@ def __init__(
2727
retriever: ColbertBaseRetriever,
2828
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
2929
callback_manager: Optional[CallbackManager] = None,
30-
object_map: Optional[dict] = None,
30+
object_map: Optional[Dict[str, Any]] = None,
3131
verbose: bool = False,
3232
query_maxlen: int = -1,
3333
) -> None:
@@ -51,6 +51,6 @@ def _retrieve(
5151
query_maxlen=self._query_maxlen,
5252
)
5353
return [
54-
NodeWithScore(node=TextNode(text=c.text, metadata=c.metadata), score=s)
54+
NodeWithScore(node=TextNode(text=c.text, extra_info=c.metadata), score=s)
5555
for (c, s) in chunk_scores
5656
]

libs/llamaindex/tests/integration_tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from _pytest.fixtures import FixtureRequest
23
from cassandra.cluster import Session
34
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore
45

@@ -17,7 +18,7 @@ def astra_db() -> AstraDBTestStore:
1718

1819

1920
@pytest.fixture()
20-
def session(request) -> Session:
21+
def session(request: FixtureRequest) -> Session:
2122
test_store = request.getfixturevalue(request.param)
2223
session = test_store.create_cassandra_session()
2324
session.default_timeout = 180

libs/llamaindex/tests/integration_tests/test_colbert.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from cassandra.cluster import Session
66
from llama_index.core import Settings, get_response_synthesizer
77
from llama_index.core.ingestion import IngestionPipeline
8+
from llama_index.core.llms import MockLLM
89
from llama_index.core.query_engine import RetrieverQueryEngine
9-
from llama_index.core.schema import Document, NodeWithScore
10+
from llama_index.core.schema import Document, NodeWithScore, QueryBundle
1011
from llama_index.core.text_splitter import SentenceSplitter
1112
from ragstack_colbert import (
1213
CassandraDatabase,
@@ -20,7 +21,7 @@
2021
logging.getLogger("cassandra").setLevel(logging.ERROR)
2122

2223

23-
def validate_retrieval(results: List[NodeWithScore], key_value: str):
24+
def validate_retrieval(results: List[NodeWithScore], key_value: str) -> bool:
2425
passed = False
2526
for result in results:
2627
if key_value in result.text:
@@ -29,7 +30,7 @@ def validate_retrieval(results: List[NodeWithScore], key_value: str):
2930

3031

3132
@pytest.mark.parametrize("session", ["astra_db"], indirect=["session"]) # "cassandra",
32-
def test_sync(session: Session):
33+
def test_sync(session: Session) -> None:
3334
table_name = "LlamaIndex_colbert_sync"
3435

3536
batch_size = 5 # 640 recommended for production use
@@ -47,15 +48,15 @@ def test_sync(session: Session):
4748
embedding_model=embedding_model,
4849
)
4950

50-
docs: List[Document] = []
51+
docs = []
5152
docs.append(
5253
Document(
53-
text=TestData.marine_animals_text(), metadata={"name": "marine_animals"}
54+
text=TestData.marine_animals_text(), extra_info={"name": "marine_animals"}
5455
)
5556
)
5657
docs.append(
5758
Document(
58-
text=TestData.nebula_voyager_text(), metadata={"name": "nebula_voyager"}
59+
text=TestData.nebula_voyager_text(), extra_info={"name": "nebula_voyager"}
5960
)
6061
)
6162

@@ -64,20 +65,20 @@ def test_sync(session: Session):
6465

6566
nodes = pipeline.run(documents=docs)
6667

67-
docs: Dict[str, Tuple[List[str], List[Metadata]]] = {}
68+
docs2: Dict[str, Tuple[List[str], List[Metadata]]] = {}
6869

6970
for node in nodes:
7071
doc_id = node.metadata["name"]
71-
if doc_id not in docs:
72-
docs[doc_id] = ([], [])
73-
docs[doc_id][0].append(node.text)
74-
docs[doc_id][1].append(node.metadata)
72+
if doc_id not in docs2:
73+
docs2[doc_id] = ([], [])
74+
docs2[doc_id][0].append(node.text)
75+
docs2[doc_id][1].append(node.metadata)
7576

7677
logging.debug("Starting to embed ColBERT docs and save them to the database")
7778

78-
for doc_id in docs:
79-
texts = docs[doc_id][0]
80-
metadatas = docs[doc_id][1]
79+
for doc_id in docs2:
80+
texts = docs2[doc_id][0]
81+
metadatas = docs2[doc_id][1]
8182

8283
logging.debug("processing %s that has %s chunks", doc_id, len(texts))
8384

@@ -87,22 +88,24 @@ def test_sync(session: Session):
8788
retriever=vector_store.as_retriever(), similarity_top_k=5
8889
)
8990

90-
Settings.llm = None
91+
Settings.llm = MockLLM()
9192

9293
response_synthesizer = get_response_synthesizer()
9394

94-
pipeline = RetrieverQueryEngine(
95+
pipeline2 = RetrieverQueryEngine(
9596
retriever=retriever,
9697
response_synthesizer=response_synthesizer,
9798
)
9899

99-
results = pipeline.retrieve("Who developed the Astroflux Navigator?")
100+
results = pipeline2.retrieve(QueryBundle("Who developed the Astroflux Navigator?"))
100101
assert validate_retrieval(results, key_value="Astroflux Navigator")
101102

102-
results = pipeline.retrieve(
103-
"Describe the phenomena known as 'Chrono-spatial Echoes'"
103+
results = pipeline2.retrieve(
104+
QueryBundle("Describe the phenomena known as 'Chrono-spatial Echoes'")
104105
)
105106
assert validate_retrieval(results, key_value="Chrono-spatial Echoes")
106107

107-
results = pipeline.retrieve("How do anglerfish adapt to the deep ocean's darkness?")
108+
results = pipeline2.retrieve(
109+
QueryBundle("How do anglerfish adapt to the deep ocean's darkness?")
110+
)
108111
assert validate_retrieval(results, key_value="anglerfish")

libs/llamaindex/tests/unit_tests/test_import.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import importlib
2+
from typing import Any, Callable
23

34

4-
def test_import():
5+
def test_import() -> None:
56
import astrapy # noqa: F401
67
import cassio # noqa: F401
78
import openai # noqa: F401
@@ -11,25 +12,25 @@ def test_import():
1112
from llama_index.vector_stores.cassandra import CassandraVectorStore # noqa: F401
1213

1314

14-
def check_no_import(fn: callable):
15+
def check_no_import(fn: Callable[[], Any]) -> None:
1516
try:
1617
fn()
1718
raise RuntimeError("Should have failed to import")
1819
except ImportError:
1920
pass
2021

2122

22-
def test_not_import():
23+
def test_not_import() -> None:
2324
check_no_import(lambda: importlib.import_module("langchain.vectorstores"))
2425
check_no_import(lambda: importlib.import_module("langchain_astradb"))
2526
check_no_import(lambda: importlib.import_module("langchain_core"))
2627
check_no_import(lambda: importlib.import_module("langsmith"))
2728

2829

29-
def test_meta():
30+
def test_meta() -> None:
3031
from importlib import metadata
3132

32-
def check_meta(package: str):
33+
def check_meta(package: str) -> None:
3334
meta = metadata.metadata(package)
3435
assert meta["version"]
3536
assert meta["license"] == "BUSL-1.1"

libs/llamaindex/tox.ini

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,30 @@
11
[tox]
22
min_version = 4.0
3-
envlist = py311
3+
envlist = type, unit-tests, integration-tests
4+
5+
[testenv]
6+
description = install dependencies
7+
skip_install = true
8+
allowlist_externals = poetry
9+
commands_pre =
10+
poetry env use system
11+
poetry install -E colbert
412

513
[testenv:unit-tests]
614
description = run unit tests
7-
deps =
8-
poetry
915
commands =
10-
poetry install
11-
poetry build
1216
poetry run pytest --disable-warnings {toxinidir}/tests/unit_tests
1317

1418
[testenv:integration-tests]
1519
description = run integration tests
16-
deps =
17-
poetry
1820
pass_env =
1921
ASTRA_DB_TOKEN
2022
ASTRA_DB_ID
2123
ASTRA_DB_ENV
2224
commands =
23-
poetry install -E colbert
2425
poetry run pytest --disable-warnings {toxinidir}/tests/integration_tests
26+
27+
[testenv:type]
28+
description = run type checking
29+
commands =
30+
poetry run mypy {toxinidir}

0 commit comments

Comments
 (0)