Skip to content

Commit f10a2cc

Browse files
authored
Add some None return types (#636)
1 parent 626b39f commit f10a2cc

28 files changed

+68
-68
lines changed

examples/evaluation/tru_shared.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def get_astra_vector_store(framework: Framework, collection_name: str):
192192
)
193193

194194

195-
def execute_query(framework: Framework, pipeline, query):
195+
def execute_query(framework: Framework, pipeline, query) -> None:
196196
if framework == Framework.LANG_CHAIN:
197197
pipeline.invoke(query)
198198
elif framework == Framework.LLAMA_INDEX:
@@ -204,7 +204,7 @@ def execute_query(framework: Framework, pipeline, query):
204204

205205

206206
# runs the pipeline across all queries in all known datasets
207-
def execute_experiment(framework: Framework, pipeline, experiment_name: str):
207+
def execute_experiment(framework: Framework, pipeline, experiment_name: str) -> None:
208208
init_tru()
209209

210210
# use a short uuid to ensure that multiple experiments with the same name don't

examples/notebooks/advancedRAG.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
"import textwrap\n",
102102
"\n",
103103
"\n",
104-
"def pprint_docs(docs):\n",
104+
"def pprint_docs(docs) -> None:\n",
105105
" print(\n",
106106
" f\"\\n{'-' * 70}\\n\".join(\n",
107107
" [\n",
@@ -112,7 +112,7 @@
112112
" )\n",
113113
"\n",
114114
"\n",
115-
"def pprint_result(result):\n",
115+
"def pprint_result(result) -> None:\n",
116116
" print(\"Answer: \" + \"\\n\".join(textwrap.wrap(result)))"
117117
]
118118
},
@@ -244,7 +244,7 @@
244244
"from langchain.callbacks import get_openai_callback\n",
245245
"\n",
246246
"\n",
247-
"def do_retrieval(chain):\n",
247+
"def do_retrieval(chain) -> None:\n",
248248
" for i in range(len(questions)):\n",
249249
" print(\"-\" * 40)\n",
250250
" print(f\"Question: {questions[i]}\\n\")\n",

examples/notebooks/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_required_env(name) -> str:
2929
)
3030

3131

32-
def try_delete_with_backoff(collection: str, sleep=1, max_tries=2):
32+
def try_delete_with_backoff(collection: str, sleep=1, max_tries=2) -> None:
3333
try:
3434
logging.info("deleting collection %s", collection)
3535
response = client.delete_collection(collection)
@@ -46,7 +46,7 @@ def try_delete_with_backoff(collection: str, sleep=1, max_tries=2):
4646
try_delete_with_backoff(collection, sleep * 2, max_tries)
4747

4848

49-
def before_notebook():
49+
def before_notebook() -> None:
5050
collections = client.get_collections().get("status").get("collections")
5151
logging.info("Existing collections: %s", collections)
5252
for collection in collections:

examples/notebooks/nemo_guardrails.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@
192192
"\"\"\"\n",
193193
"\n",
194194
"\n",
195-
"def yaml_config(engine, model):\n",
195+
"def yaml_config(engine, model) -> str:\n",
196196
" return f\"\"\"\n",
197197
" models:\n",
198198
" - type: main\n",
@@ -244,7 +244,7 @@
244244
" return ActionResult(return_value=answer, context_updates=context_updates)\n",
245245
"\n",
246246
"\n",
247-
"def init(app: LLMRails):\n",
247+
"def init(app: LLMRails) -> None:\n",
248248
" app.register_action(rag, \"rag\")"
249249
]
250250
},

libs/e2e-tests/e2e_tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,13 @@ def pytest_runtest_makereport(item, call):
156156
os.environ["RAGSTACK_E2E_TESTS_TEST_START"] = str(time.perf_counter_ns())
157157

158158

159-
def set_current_test_info(test_name: str, test_info: str):
159+
def set_current_test_info(test_name: str, test_info: str) -> None:
160160
test_info = test_info.replace("_", "-")
161161
os.environ["RAGSTACK_E2E_TESTS_TEST_INFO"] = f"{test_name}::{test_info}"
162162

163163

164164
@pytest.hookimpl()
165-
def pytest_sessionfinish():
165+
def pytest_sessionfinish() -> None:
166166
logging.info("All tests report:")
167167
logging.info("\n".join(all_report_lines))
168168
logging.info("Failed tests report:")
@@ -184,7 +184,7 @@ def pytest_sessionfinish():
184184
_report_to_file("", "llamaindex-tests-report.txt", llamaindex_report_lines)
185185

186186

187-
def _report_to_file(stats_str: str, filename: str, report_lines: list):
187+
def _report_to_file(stats_str: str, filename: str, report_lines: list) -> None:
188188
report_lines.sort()
189189
with open(filename, "w") as f:
190190
if stats_str:

libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _colang() -> str:
5353

5454

5555
class NeMoRag:
56-
def __init__(self, retriever):
56+
def __init__(self, retriever) -> None:
5757
self.retriever = retriever
5858

5959
async def rag_using_lc(self, context: dict, llm: BaseLLM) -> ActionResult:
@@ -77,7 +77,7 @@ async def rag_using_lc(self, context: dict, llm: BaseLLM) -> ActionResult:
7777

7878
return ActionResult(return_value=answer, context_updates=context_updates)
7979

80-
def init(self, app: LLMRails):
80+
def init(self, app: LLMRails) -> None:
8181
app.register_action(self.rag_using_lc, "rag")
8282

8383

libs/e2e-tests/e2e_tests/langchain/test_astra.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@
2525
MINIMUM_ACCEPTABLE_SCORE = 0.1
2626

2727

28-
def test_basic_vector_search(vectorstore: AstraDBVectorStore):
28+
def test_basic_vector_search(vectorstore: AstraDBVectorStore) -> None:
2929
print("Running test_basic_vector_search")
3030
vectorstore.add_texts(["RAGStack is a framework to run LangChain in production"])
3131
retriever = vectorstore.as_retriever()
3232
assert len(retriever.get_relevant_documents("RAGStack")) > 0
3333

3434

35-
def test_ingest_errors(vectorstore: AstraDBVectorStore):
35+
def test_ingest_errors(vectorstore: AstraDBVectorStore) -> None:
3636
print("Running test_ingestion")
3737

3838
empty_text = ""
@@ -75,7 +75,7 @@ def test_ingest_errors(vectorstore: AstraDBVectorStore):
7575
)
7676

7777

78-
def test_wrong_connection_parameters(vectorstore: AstraDBVectorStore):
78+
def test_wrong_connection_parameters(vectorstore: AstraDBVectorStore) -> None:
7979
try:
8080
AstraDBVectorStore(
8181
collection_name="something",
@@ -109,7 +109,7 @@ def test_wrong_connection_parameters(vectorstore: AstraDBVectorStore):
109109
)
110110

111111

112-
def test_basic_metadata_filtering_no_vector(vectorstore: AstraDBVectorStore):
112+
def test_basic_metadata_filtering_no_vector(vectorstore: AstraDBVectorStore) -> None:
113113
collection = vectorstore.collection
114114
vectorstore.add_texts(
115115
texts=["RAGStack is a framework to run LangChain in production"],
@@ -202,7 +202,7 @@ def test_basic_metadata_filtering_no_vector(vectorstore: AstraDBVectorStore):
202202
)
203203

204204

205-
def verify_document(document, expected_content, expected_metadata):
205+
def verify_document(document, expected_content, expected_metadata) -> None:
206206
if isinstance(document, Document):
207207
assert document.page_content == expected_content
208208
assert document.metadata == expected_metadata
@@ -211,7 +211,7 @@ def verify_document(document, expected_content, expected_metadata):
211211
assert document.get("metadata") == expected_metadata
212212

213213

214-
def test_vector_search_with_metadata(vectorstore: VectorStore):
214+
def test_vector_search_with_metadata(vectorstore: VectorStore) -> None:
215215
print("Running test_vector_search_with_metadata")
216216

217217
document_ids = vectorstore.add_texts(
@@ -409,7 +409,7 @@ def test_vector_search_with_metadata(vectorstore: VectorStore):
409409

410410

411411
@pytest.mark.skip()
412-
def test_stress_astra():
412+
def test_stress_astra() -> None:
413413
handler = AstraDBVectorStoreHandler(VectorStoreImplementation.ASTRADB)
414414
while True:
415415
context = handler.before_test()
@@ -423,7 +423,7 @@ def test_stress_astra():
423423

424424

425425
class MockEmbeddings(Embeddings):
426-
def __init__(self):
426+
def __init__(self) -> None:
427427
self.embedded_documents = None
428428
self.embedded_query = None
429429

libs/e2e-tests/e2e_tests/langchain/test_cassandra_tool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@pytest.mark.usefixtures("cassandra")
17-
def test_tool_with_openai_tool():
17+
def test_tool_with_openai_tool() -> None:
1818
session = cassio.config.resolve_session()
1919
session.execute("DROP TABLE IF EXISTS default_keyspace.tool_table_users;")
2020

libs/e2e-tests/e2e_tests/langchain/test_compatibility_rag.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def llm():
304304
),
305305
],
306306
)
307-
def test_rag(test_case, vector_store, embedding, llm, request, record_property):
307+
def test_rag(test_case, vector_store, embedding, llm, request, record_property) -> None:
308308
set_current_test_info(
309309
"langchain::" + test_case,
310310
f"{llm},{embedding},{vector_store}",
@@ -327,7 +327,7 @@ def _run_test(
327327
embedding_fn,
328328
resolved_llm,
329329
record_property,
330-
):
330+
) -> None:
331331
# NeMo guardrails running only with certain LLMs
332332
if test_case == "nemo_guardrails" and not resolved_llm["nemo_config"]:
333333
skip_test_due_to_implementation_not_supported("nemo_guardrails")
@@ -408,7 +408,7 @@ def gemini_pro_llm():
408408
("vertex_gemini_multimodal_embedding", "gemini_flash_llm"),
409409
],
410410
)
411-
def test_multimodal(vector_store, embedding, llm, request, record_property):
411+
def test_multimodal(vector_store, embedding, llm, request, record_property) -> None:
412412
set_current_test_info(
413413
"langchain::multimodal_rag",
414414
f"{llm},{embedding},{vector_store}",
@@ -487,7 +487,7 @@ def embed_query(self, text: str) -> list[float]:
487487

488488

489489
@pytest.mark.parametrize("chat", ["vertex_gemini_pro_llm", "gemini_pro_llm"])
490-
def test_chat(chat, request, record_property):
490+
def test_chat(chat, request, record_property) -> None:
491491
set_current_test_info(
492492
"langchain::chat",
493493
chat,

libs/e2e-tests/e2e_tests/langchain/test_document_loaders.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
from e2e_tests.test_utils.vector_store_handler import VectorStoreImplementation
1818

1919

20-
def set_current_test_info_document_loader(doc_loader: str):
20+
def set_current_test_info_document_loader(doc_loader: str) -> None:
2121
set_current_test_info("langchain::document_loader", doc_loader)
2222

2323

24-
def test_csv_loader():
24+
def test_csv_loader() -> None:
2525
set_current_test_info_document_loader("csv")
2626
with tempfile.NamedTemporaryFile(mode="w+", suffix=".csv") as temp_csv_file:
2727
with open(temp_csv_file.name, "w") as write:
@@ -41,7 +41,7 @@ def test_csv_loader():
4141
assert doc1.metadata == {"row": 0, "source": temp_csv_file.name}
4242

4343

44-
def test_web_based_loader():
44+
def test_web_based_loader() -> None:
4545
set_current_test_info_document_loader("web")
4646
loader = WebBaseLoader(
4747
["https://langstream.ai/changelog/", "https://langstream.ai/faq/"]
@@ -69,7 +69,7 @@ def test_web_based_loader():
6969
}
7070

7171

72-
def test_s3_loader():
72+
def test_s3_loader() -> None:
7373
set_current_test_info_document_loader("s3")
7474
aws_region = "us-east-1"
7575
bucket_name = f"ragstack-ci-{uuid.uuid4()}"
@@ -95,7 +95,7 @@ def test_s3_loader():
9595
bucket.delete()
9696

9797

98-
def test_azure_blob_doc_loader():
98+
def test_azure_blob_doc_loader() -> None:
9999
set_current_test_info_document_loader("azure")
100100
from azure.storage.blob import BlobClient
101101

0 commit comments

Comments
 (0)