Skip to content

Commit 9c47dc5

Browse files
authored
Add type checking of ragstack-ragulate (#615)
1 parent c16cbd5 commit 9c47dc5

26 files changed

+963
-816
lines changed

libs/ragulate/colbert_chunk_size_and_k.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,13 @@
33
import os
44
import time
55
from pathlib import Path
6-
from typing import List
6+
from typing import Any, List
77

88
from langchain.text_splitter import RecursiveCharacterTextSplitter
99
from langchain_community.document_loaders import UnstructuredFileLoader
1010
from langchain_core.output_parsers import StrOutputParser
1111
from langchain_core.prompts import ChatPromptTemplate
12-
from langchain_core.runnables import RunnablePassthrough
12+
from langchain_core.runnables import Runnable, RunnablePassthrough
1313
from langchain_openai import ChatOpenAI
1414
from ragstack_colbert import (
1515
CassandraDatabase,
@@ -74,7 +74,7 @@ def len_function(text: str) -> int:
7474
return len(tokenizer.tokenize(text))
7575

7676

77-
async def ingest(file_path: str, chunk_size: int, **_):
77+
async def ingest(file_path: str, chunk_size: int, **_: Any) -> None:
7878
doc_id = Path(file_path).name
7979

8080
chunk_overlap = min(chunk_size / 4, 64)
@@ -134,9 +134,9 @@ async def ingest(file_path: str, chunk_size: int, **_):
134134
)
135135

136136

137-
def query_pipeline(k: int, chunk_size: int, **_):
137+
def query_pipeline(k: int, chunk_size: int, **_: Any) -> Runnable[Any, Any]:
138138
vector_store = get_lc_vector_store(chunk_size=chunk_size)
139-
llm = ChatOpenAI(model_name=LLM_MODEL)
139+
llm = ChatOpenAI(model=LLM_MODEL)
140140

141141
# build a prompt
142142
prompt_template = """

libs/ragulate/open_ai_chunk_size_and_k.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,20 @@
11
# ruff: noqa: D103, INP001, T201
22
import os
3+
from typing import Any
34

45
from langchain_astradb import AstraDBVectorStore
56
from langchain_community.document_loaders import UnstructuredFileLoader
67
from langchain_core.output_parsers import StrOutputParser
78
from langchain_core.prompts import ChatPromptTemplate
8-
from langchain_core.runnables import RunnablePassthrough
9+
from langchain_core.runnables import Runnable, RunnablePassthrough
910
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
1011
from langchain_text_splitters import RecursiveCharacterTextSplitter
1112

1213
EMBEDDING_MODEL = "text-embedding-3-small"
1314
LLM_MODEL = "gpt-3.5-turbo"
1415

1516

16-
def get_vector_store(chunk_size: int):
17+
def get_vector_store(chunk_size: int) -> AstraDBVectorStore:
1718
return AstraDBVectorStore(
1819
embedding=OpenAIEmbeddings(model=EMBEDDING_MODEL),
1920
collection_name=f"chunk_size_{chunk_size}",
@@ -22,7 +23,7 @@ def get_vector_store(chunk_size: int):
2223
)
2324

2425

25-
def ingest(file_path: str, chunk_size: int, **_):
26+
def ingest(file_path: str, chunk_size: int, **_: Any) -> None:
2627
vector_store = get_vector_store(chunk_size=chunk_size)
2728

2829
chunk_overlap = min(chunk_size / 4, 64)
@@ -40,9 +41,9 @@ def ingest(file_path: str, chunk_size: int, **_):
4041
vector_store.add_documents(split_docs)
4142

4243

43-
def query_pipeline(k: int, chunk_size: int, **_):
44+
def query_pipeline(k: int, chunk_size: int, **_: Any) -> Runnable[Any, Any]:
4445
vector_store = get_vector_store(chunk_size=chunk_size)
45-
llm = ChatOpenAI(model_name=LLM_MODEL)
46+
llm = ChatOpenAI(model=LLM_MODEL)
4647

4748
# build a prompt
4849
prompt_template = """

libs/ragulate/poetry.lock

Lines changed: 806 additions & 708 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/ragulate/pyproject.toml

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,9 @@ langchain-community = "0.0.38"
3535
langchain-core = "0.1.52"
3636
langchain-openai = "0.1.3"
3737
pytest = "^8.2.2"
38+
mypy = "^1.11.0"
39+
types-pyyaml = "^6.0.1"
40+
types-aiofiles = "^23.2.0.0"
3841

3942
[build-system]
4043
requires = ["poetry-core", "wheel"]
@@ -44,3 +47,20 @@ build-backend = "poetry.core.masonry.api"
4447
ragulate = "ragstack_ragulate.cli:main"
4548
test_unit = "scripts.test_unit_runner:main"
4649
test_integration = "scripts.test_integration_runner:main"
50+
51+
[tool.mypy]
52+
disallow_any_generics = true
53+
disallow_incomplete_defs = true
54+
disallow_untyped_calls = true
55+
disallow_untyped_decorators = true
56+
disallow_untyped_defs = true
57+
follow_imports = "normal"
58+
ignore_missing_imports = true
59+
no_implicit_reexport = true
60+
show_error_codes = true
61+
show_error_context = true
62+
strict_equality = true
63+
strict_optional = true
64+
warn_redundant_casts = true
65+
warn_return_any = true
66+
warn_unused_ignores = true

libs/ragulate/ragstack_ragulate/analysis.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,10 @@
1-
from typing import List
1+
from typing import Any, Dict, List, Tuple
22

33
import matplotlib.pyplot as plt
44
import numpy as np
55
import pandas as pd
66
import plotly.graph_objects as go
77
import seaborn as sns
8-
from pandas import DataFrame
98
from plotly.io import write_image
109

1110
from .utils import get_tru
@@ -14,7 +13,7 @@
1413
class Analysis:
1514
"""Analysis class."""
1615

17-
def get_all_data(self, recipes: List[str]) -> DataFrame:
16+
def get_all_data(self, recipes: List[str]) -> Tuple[pd.DataFrame, List[str]]:
1817
"""Get all data from the recipes."""
1918
df_all = pd.DataFrame()
2019

@@ -55,9 +54,11 @@ def get_all_data(self, recipes: List[str]) -> DataFrame:
5554

5655
return reset_df, list(set(all_metrics))
5756

58-
def calculate_statistics(self, df: pd.DataFrame, metrics: list):
57+
def calculate_statistics(
58+
self, df: pd.DataFrame, metrics: List[str]
59+
) -> Dict[str, Any]:
5960
"""Calculate statistics."""
60-
stats = {}
61+
stats: Dict[str, Any] = {}
6162
for recipe in df["recipe"].unique():
6263
stats[recipe] = {}
6364
for metric in metrics:
@@ -76,7 +77,7 @@ def calculate_statistics(self, df: pd.DataFrame, metrics: list):
7677
}
7778
return stats
7879

79-
def output_box_plots_by_dataset(self, df: DataFrame, metrics: List[str]):
80+
def output_box_plots_by_dataset(self, df: pd.DataFrame, metrics: List[str]) -> None:
8081
"""Output box plots by dataset."""
8182
stats = self.calculate_statistics(df, metrics)
8283
recipes = sorted(df["recipe"].unique(), key=lambda x: x.lower())
@@ -101,7 +102,6 @@ def output_box_plots_by_dataset(self, df: DataFrame, metrics: List[str]):
101102
q1 = []
102103
median = []
103104
q3 = []
104-
mean = []
105105
low = []
106106
high = []
107107
for metric in metrics:
@@ -120,7 +120,7 @@ def output_box_plots_by_dataset(self, df: DataFrame, metrics: List[str]):
120120
q1=q1,
121121
median=median,
122122
q3=q3,
123-
mean=mean,
123+
mean=[],
124124
lowerfence=low,
125125
upperfence=high,
126126
name=recipe,
@@ -159,7 +159,9 @@ def output_box_plots_by_dataset(self, df: DataFrame, metrics: List[str]):
159159

160160
write_image(fig, f"./{dataset}_box_plot.png")
161161

162-
def output_histograms_by_dataset(self, df: pd.DataFrame, metrics: List[str]):
162+
def output_histograms_by_dataset(
163+
self, df: pd.DataFrame, metrics: List[str]
164+
) -> None:
163165
"""Output histograms by dataset."""
164166
# Append "latency" to the metrics list
165167
metrics.append("latency")
@@ -184,7 +186,7 @@ def output_histograms_by_dataset(self, df: pd.DataFrame, metrics: List[str]):
184186
sns.set_theme(style="darkgrid")
185187

186188
# Custom function to set bin ranges and filter invalid values
187-
def custom_hist(data, **kws):
189+
def custom_hist(data: Dict[str, Any], **kws: Any) -> None:
188190
metric = data["metric"].iloc[0]
189191
data = data[
190192
np.isfinite(data["value"])
@@ -251,12 +253,12 @@ def custom_hist(data, **kws):
251253
# Close the plot to avoid displaying it
252254
plt.close()
253255

254-
def compare(self, recipes: List[str], output: str):
256+
def compare(self, recipes: List[str], output: str = "box-plots") -> None:
255257
"""Compare results from 2 (or more) recipes."""
256258
df, metrics = self.get_all_data(recipes=recipes)
257259
if output == "box-plots":
258260
self.output_box_plots_by_dataset(df=df, metrics=metrics)
259261
elif output == "histogram-grid":
260262
self.output_histograms_by_dataset(df=df, metrics=metrics)
261263
else:
262-
raise ValueError
264+
raise ValueError(f"Invalid output type: {output}")

libs/ragulate/ragstack_ragulate/cli_commands/compare.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
1-
from typing import List, Optional
1+
from argparse import ArgumentParser, _SubParsersAction
2+
from typing import Any, List
23

34
from ragstack_ragulate.analysis import Analysis
45

56
from .utils import remove_sqlite_extension
67

78

8-
def setup_compare(subparsers):
9+
def setup_compare(subparsers: _SubParsersAction[ArgumentParser]) -> None:
910
"""Setup the compare command."""
1011
compare_parser = subparsers.add_parser(
1112
"compare", help="Compare results from 2 (or more) recipes"
@@ -30,9 +31,9 @@ def setup_compare(subparsers):
3031

3132
def call_compare(
3233
recipe: List[str],
33-
output: Optional[str] = "box-plots",
34-
**_,
35-
):
34+
output: str = "box-plots",
35+
**_: Any,
36+
) -> None:
3637
"""Compare results from 2 (or more) recipes."""
3738
analysis = Analysis()
3839

libs/ragulate/ragstack_ragulate/cli_commands/dashboard.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,12 @@
1+
from argparse import ArgumentParser, _SubParsersAction
2+
from typing import Any
3+
14
from ragstack_ragulate.dashboard import run_dashboard
25

36
from .utils import remove_sqlite_extension
47

58

6-
def setup_dashboard(subparsers):
9+
def setup_dashboard(subparsers: _SubParsersAction[ArgumentParser]) -> None:
710
"""Setup the dashboard command."""
811
dashboard_parser = subparsers.add_parser(
912
"dashboard",
@@ -29,8 +32,8 @@ def setup_dashboard(subparsers):
2932
def call_dashboard(
3033
recipe: str,
3134
port: int,
32-
**_,
33-
):
35+
**_: Any,
36+
) -> None:
3437
"""Runs the TruLens dashboard."""
3538
recipe_name = remove_sqlite_extension(recipe)
3639
run_dashboard(recipe_name=recipe_name, port=port)

libs/ragulate/ragstack_ragulate/cli_commands/download.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,10 @@
1+
from argparse import ArgumentParser, _SubParsersAction
2+
from typing import Any
3+
14
from ragstack_ragulate.datasets import get_dataset
25

36

4-
def setup_download(subparsers):
7+
def setup_download(subparsers: _SubParsersAction[ArgumentParser]) -> None:
58
"""Setup the download command."""
69
download_parser = subparsers.add_parser("download", help="Download a dataset")
710
download_parser.add_argument(
@@ -22,7 +25,7 @@ def setup_download(subparsers):
2225
download_parser.set_defaults(func=lambda args: call_download(**vars(args)))
2326

2427

25-
def call_download(dataset_name: str, kind: str, **_):
28+
def call_download(dataset_name: str, kind: str, **_: Any) -> None:
2629
"""Download a dataset."""
2730
dataset = get_dataset(name=dataset_name, kind=kind)
2831
dataset.download_dataset()

libs/ragulate/ragstack_ragulate/cli_commands/ingest.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
1-
from typing import List
1+
from argparse import ArgumentParser, _SubParsersAction
2+
from typing import Any, List
23

34
from ragstack_ragulate.datasets import find_dataset
45
from ragstack_ragulate.pipelines import IngestPipeline
56
from ragstack_ragulate.utils import convert_vars_to_ingredients
67

78

8-
def setup_ingest(subparsers):
9+
def setup_ingest(subparsers: _SubParsersAction[ArgumentParser]) -> None:
910
"""Setup the ingest command."""
1011
ingest_parser = subparsers.add_parser("ingest", help="Run an ingest pipeline")
1112
ingest_parser.add_argument(
@@ -58,8 +59,8 @@ def call_ingest(
5859
var_name: List[str],
5960
var_value: List[str],
6061
dataset: List[str],
61-
**_,
62-
):
62+
**_: Any,
63+
) -> None:
6364
"""Run an ingest pipeline."""
6465
datasets = [find_dataset(name=name) for name in dataset]
6566

libs/ragulate/ragstack_ragulate/cli_commands/query.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,12 @@
1-
from typing import List
1+
from argparse import ArgumentParser, _SubParsersAction
2+
from typing import Any, List
23

34
from ragstack_ragulate.datasets import find_dataset
45
from ragstack_ragulate.pipelines import QueryPipeline
56
from ragstack_ragulate.utils import convert_vars_to_ingredients
67

78

8-
def setup_query(subparsers):
9+
def setup_query(subparsers: _SubParsersAction[ArgumentParser]) -> None:
910
"""Setup the query command."""
1011
query_parser = subparsers.add_parser("query", help="Run a query pipeline")
1112
query_parser.add_argument(
@@ -110,8 +111,8 @@ def call_query(
110111
restart: bool,
111112
provider: str,
112113
model: str,
113-
**_,
114-
):
114+
**_: Any,
115+
) -> None:
115116
"""Run a query pipeline."""
116117
if sample <= 0.0 or sample > 1.0:
117118
raise ValueError("Sample percent must be between 0 and 1")

0 commit comments

Comments
 (0)