Commit 303bbca

fix(types): fix pyright type issues with latest pyright version (#366)
1 parent 0430e8f commit 303bbca

13 files changed: +74 -43 lines changed

requirements/dev.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ rich
 ruff
 isort
 black[jupyter]
-pyright==1.1.338
+pyright
 llama_index
 notebook
 sphinx-autobuild

src/ragas/llms/base.py

Lines changed: 3 additions & 3 deletions
@@ -25,7 +25,7 @@ class RagasLLM(ABC):
 
     @property
     @abstractmethod
-    def llm(self):
+    def llm(self) -> t.Any:
         ...
 
     def validate_api_key(self):
@@ -39,15 +39,15 @@ def generate(
         self,
         prompts: list[ChatPromptTemplate],
         n: int = 1,
-        temperature: float = 0,
+        temperature: float = 1e-8,
         callbacks: t.Optional[Callbacks] = None,
     ) -> LLMResult:
         ...
 
     @abstractmethod
     async def agenerate(
         self,
-        prompts: ChatPromptTemplate,
+        prompt: ChatPromptTemplate,
         n: int = 1,
         temperature: float = 1e-8,
         callbacks: t.Optional[Callbacks] = None,
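
For orientation, here is a minimal sketch of the RagasLLM interface after this change, reconstructed from the hunks above. The import paths and the decorators on generate() sit outside the shown context lines, so treat them as assumptions, not part of the diff.

# Sketch only: reconstructed from the hunks above; import locations assumed.
from __future__ import annotations

import typing as t
from abc import ABC, abstractmethod

from langchain.callbacks.base import Callbacks
from langchain.prompts import ChatPromptTemplate
from langchain.schema import LLMResult


class RagasLLM(ABC):
    @property
    @abstractmethod
    def llm(self) -> t.Any:
        # annotated as Any so pyright accepts the property; subclasses narrow it
        ...

    @abstractmethod
    def generate(
        self,
        prompts: list[ChatPromptTemplate],
        n: int = 1,
        temperature: float = 1e-8,  # near-zero default instead of 0
        callbacks: t.Optional[Callbacks] = None,
    ) -> LLMResult:
        ...

    @abstractmethod
    async def agenerate(
        self,
        prompt: ChatPromptTemplate,  # renamed from `prompts`: one prompt per call
        n: int = 1,
        temperature: float = 1e-8,
        callbacks: t.Optional[Callbacks] = None,
    ) -> LLMResult:
        ...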

src/ragas/llms/langchain.py

Lines changed: 2 additions & 1 deletion
@@ -77,7 +77,7 @@ def __init__(self, llm: BaseLLM | BaseChatModel):
         self.langchain_llm = llm
 
     @property
-    def llm(self):
+    def llm(self) -> BaseLLM | BaseChatModel:
         return self.langchain_llm
 
     def validate_api_key(self):
@@ -140,6 +140,7 @@ async def agenerate(
         self,
         prompt: ChatPromptTemplate,
         n: int = 1,
+        temperature: float = 1e-8,
         callbacks: t.Optional[Callbacks] = None,
     ) -> LLMResult:
         temperature = 0.2 if n > 1 else 0
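
The added temperature parameter appears to exist mainly to keep agenerate's signature compatible with the base class: the first statement in the shown context (temperature = 0.2 if n > 1 else 0) immediately overwrites whatever value the caller passes.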

src/ragas/llms/openai.py

Lines changed: 1 addition & 1 deletion
@@ -109,7 +109,7 @@ def __init__(self, model: str, _api_key_env_var: str, timeout: int = 60) -> None
         self._client: AsyncClient
 
     @abstractmethod
-    def _client_init(self) -> AsyncClient:
+    def _client_init(self):
         ...
 
     @property
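
One plausible reading of this change (an assumption; the concrete subclasses are not shown in this hunk): the implementations of _client_init assign self._client rather than returning the client, so the -> AsyncClient annotation on the abstract method made the latest pyright flag their overrides. Roughly, with stand-in class names:

# Stand-in sketch, not ragas' actual classes.
from abc import ABC, abstractmethod

from openai import AsyncClient


class OpenAIBase(ABC):
    _client: AsyncClient

    @abstractmethod
    def _client_init(self):  # no return annotation, matching the diff
        ...


class OpenAI(OpenAIBase):
    def _client_init(self):
        # initializes in place and returns None, which would clash with an
        # abstract signature annotated `-> AsyncClient`
        self._client = AsyncClient()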

src/ragas/metrics/_answer_correctness.py

Lines changed: 8 additions & 4 deletions
@@ -12,6 +12,9 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 from ragas.utils import load_as_json
 
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+
 CORRECTNESS_PROMPT = HumanMessagePromptTemplate.from_template(
     """
 Extract following from given question and ground truth
@@ -70,8 +73,8 @@ class AnswerCorrectness(MetricWithLLM):
         The faithfulness object
     """
 
-    name: str = "answer_correctness"
-    evaluation_mode: EvaluationMode = EvaluationMode.qga
+    name: str = "answer_correctness"  # type: ignore[reportIncompatibleMethodOverride]
+    evaluation_mode: EvaluationMode = EvaluationMode.qga  # type: ignore[reportIncompatibleMethodOverride]
     batch_size: int = 15
     weights: list[float] = field(default_factory=lambda: [0.75, 0.25])
     answer_similarity: AnswerSimilarity | None = None
@@ -85,7 +88,7 @@ def __post_init__(self: t.Self):
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
-        callbacks: t.Optional[CallbackManager] = None,
+        callbacks: t.Optional[Callbacks] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
         question, answer, ground_truths = (
@@ -95,8 +98,9 @@ def _score_batch(
         )
         prompts = []
 
+        cb = CallbackManager.configure(inheritable_callbacks=callbacks)
         with trace_as_chain_group(
-            callback_group_name, callback_manager=callbacks
+            callback_group_name, callback_manager=cb
         ) as batch_group:
             for q, a, g in zip(question, answer, ground_truths):
                 human_prompt = CORRECTNESS_PROMPT.format(
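
The # type: ignore[reportIncompatibleMethodOverride] comments silence a check that newer pyright versions apply when a dataclass field overrides an attribute the base class declares differently. The base Metric definition is not part of this diff, so the snippet below only illustrates that pyright error with a stand-in base class, not ragas' actual code.

# Illustration only: `Metric` here is a stand-in, not ragas' MetricWithLLM.
from abc import ABC, abstractmethod
from dataclasses import dataclass


class Metric(ABC):
    @property
    @abstractmethod
    def name(self) -> str:
        ...


@dataclass
class AnswerCorrectness(Metric):
    # Recent pyright flags a plain field overriding an abstract property as
    # reportIncompatibleMethodOverride; the ignore comment suppresses it.
    name: str = "answer_correctness"  # type: ignore[reportIncompatibleMethodOverride]


print(AnswerCorrectness().name)  # prints "answer_correctness"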

src/ragas/metrics/_answer_relevance.py

Lines changed: 8 additions & 6 deletions
@@ -5,7 +5,7 @@
 
 import numpy as np
 from datasets import Dataset
-from langchain.callbacks.manager import trace_as_chain_group
+from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
 
@@ -15,7 +15,7 @@
 from ragas.utils import load_as_json
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.manager import CallbackManager
+    from langchain.callbacks.base import Callbacks
 
     from ragas.embeddings.base import RagasEmbeddings
 
@@ -86,8 +86,8 @@ class AnswerRelevancy(MetricWithLLM):
         E.g. HuggingFaceEmbeddings('BAAI/bge-base-en')
     """
 
-    name: str = "answer_relevancy"
-    evaluation_mode: EvaluationMode = EvaluationMode.qac
+    name: str = "answer_relevancy"  # type: ignore
+    evaluation_mode: EvaluationMode = EvaluationMode.qac  # type: ignore
     batch_size: int = 15
     strictness: int = 3
     embeddings: RagasEmbeddings = field(default_factory=embedding_factory)
@@ -102,16 +102,18 @@ def init_model(self):
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
-        callbacks: t.Optional[CallbackManager] = None,
+        callbacks: t.Optional[Callbacks] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
         questions, answers, contexts = (
             dataset["question"],
             dataset["answer"],
             dataset["contexts"],
         )
+
+        cb = CallbackManager.configure(inheritable_callbacks=callbacks)
         with trace_as_chain_group(
-            callback_group_name, callback_manager=callbacks
+            callback_group_name, callback_manager=cb
         ) as batch_group:
             prompts = []
             for ans, ctx in zip(answers, contexts):
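
Every metric now accepts the looser Callbacks type (a handler list, an existing manager, or None) and normalizes it with CallbackManager.configure before entering trace_as_chain_group, which expects a manager. Below is a self-contained sketch of that pattern; the handler and group name are illustrative, not from this commit.

import typing as t

from langchain.callbacks.base import Callbacks
from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
from langchain.callbacks.stdout import StdOutCallbackHandler


def score_batch(callbacks: t.Optional[Callbacks] = None) -> None:
    # Callbacks may be a list of handlers, a CallbackManager, or None;
    # configure() turns any of those into a CallbackManager.
    cb = CallbackManager.configure(inheritable_callbacks=callbacks)
    with trace_as_chain_group("batch", callback_manager=cb) as batch_group:
        # the real metrics call their LLM here with callbacks=batch_group
        pass


score_batch([StdOutCallbackHandler()])  # a plain handler list works
score_batch()                           # so does passing nothing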

src/ragas/metrics/_answer_similarity.py

Lines changed: 4 additions & 4 deletions
@@ -15,7 +15,7 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 
 if t.TYPE_CHECKING:
-    from langchain.callbacks.manager import CallbackManager
+    from langchain.callbacks.base import Callbacks
 
     from ragas.embeddings.base import RagasEmbeddings
 
@@ -42,8 +42,8 @@ class AnswerSimilarity(MetricWithLLM):
         Default 0.5
     """
 
-    name: str = "answer_similarity"
-    evaluation_mode: EvaluationMode = EvaluationMode.ga
+    name: str = "answer_similarity"  # type: ignore
+    evaluation_mode: EvaluationMode = EvaluationMode.ga  # type: ignore
     batch_size: int = 15
     embeddings: RagasEmbeddings = field(default_factory=embedding_factory)
     is_cross_encoder: bool = False
@@ -67,7 +67,7 @@ def init_model(self):
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
-        callbacks: t.Optional[CallbackManager] = None,
+        callbacks: t.Optional[Callbacks] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
         ground_truths, answers = dataset["ground_truths"], dataset["answer"]

src/ragas/metrics/_context_precision.py

Lines changed: 9 additions & 4 deletions
@@ -11,6 +11,9 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 from ragas.utils import load_as_json
 
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+
 CONTEXT_PRECISION = HumanMessagePromptTemplate.from_template(
     """\
 Verify if the information in the given context is useful in answering the question.
@@ -47,20 +50,22 @@ class ContextPrecision(MetricWithLLM):
         Batch size for openai completion.
     """
 
-    name: str = "context_precision"
-    evaluation_mode: EvaluationMode = EvaluationMode.qc
+    name: str = "context_precision"  # type: ignore
+    evaluation_mode: EvaluationMode = EvaluationMode.qc  # type: ignore
     batch_size: int = 15
 
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
-        callbacks: t.Optional[CallbackManager] = None,
+        callbacks: t.Optional[Callbacks] = None,
        callback_group_name: str = "batch",
     ) -> list:
         prompts = []
         questions, contexts = dataset["question"], dataset["contexts"]
+
+        cb = CallbackManager.configure(inheritable_callbacks=callbacks)
         with trace_as_chain_group(
-            callback_group_name, callback_manager=callbacks
+            callback_group_name, callback_manager=cb
         ) as batch_group:
             for qstn, ctx in zip(questions, contexts):
                 human_prompts = [

src/ragas/metrics/_context_recall.py

Lines changed: 8 additions & 4 deletions
@@ -11,6 +11,9 @@
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 from ragas.utils import load_as_json
 
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+
 CONTEXT_RECALL_RA = HumanMessagePromptTemplate.from_template(
     """
 Given a context, and an answer, analyze each sentence in the answer and classify if the sentence can be attributed to the given context or not. Output json with reason.
@@ -77,14 +80,14 @@ class ContextRecall(MetricWithLLM):
         Batch size for openai completion.
     """
 
-    name: str = "context_recall"
-    evaluation_mode: EvaluationMode = EvaluationMode.qcg
+    name: str = "context_recall"  # type: ignore
+    evaluation_mode: EvaluationMode = EvaluationMode.qcg  # type: ignore
     batch_size: int = 15
 
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
-        callbacks: t.Optional[CallbackManager] = None,
+        callbacks: t.Optional[Callbacks] = None,
         callback_group_name: str = "batch",
     ) -> list:
         prompts = []
@@ -94,8 +97,9 @@ def _score_batch(
             dataset["contexts"],
         )
 
+        cb = CallbackManager.configure(inheritable_callbacks=callbacks)
         with trace_as_chain_group(
-            callback_group_name, callback_manager=callbacks
+            callback_group_name, callback_manager=cb
         ) as batch_group:
             for qstn, gt, ctx in zip(question, ground_truths, contexts):
                 gt = "\n".join(gt) if isinstance(gt, list) else gt

src/ragas/metrics/_context_relevancy.py

Lines changed: 9 additions & 4 deletions
@@ -13,6 +13,9 @@
 
 from ragas.metrics.base import EvaluationMode, MetricWithLLM
 
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+
 CONTEXT_RELEVANCE = HumanMessagePromptTemplate.from_template(
     """\
 Please extract relevant sentences from the provided context that is absolutely required answer the following question. If no relevant sentences are found, or if you believe the question cannot be answered from the given context, return the phrase "Insufficient Information". While extracting candidate sentences you're not allowed to make any changes to sentences from given context.
@@ -47,8 +50,8 @@ class ContextRelevancy(MetricWithLLM):
         Batch size for openai completion.
     """
 
-    name: str = "context_relevancy"
-    evaluation_mode: EvaluationMode = EvaluationMode.qc
+    name: str = "context_relevancy"  # type: ignore
+    evaluation_mode: EvaluationMode = EvaluationMode.qc  # type: ignore
     batch_size: int = 15
     show_deprecation_warning: bool = False
 
@@ -58,7 +61,7 @@ def __post_init__(self: t.Self):
     def _score_batch(
         self: t.Self,
         dataset: Dataset,
-        callbacks: t.Optional[CallbackManager] = None,
+        callbacks: t.Optional[Callbacks] = None,
         callback_group_name: str = "batch",
     ) -> list[float]:
         if self.show_deprecation_warning:
@@ -67,8 +70,10 @@ def _score_batch(
             )
         prompts = []
         questions, contexts = dataset["question"], dataset["contexts"]
+
+        cb = CallbackManager.configure(inheritable_callbacks=callbacks)
         with trace_as_chain_group(
-            callback_group_name, callback_manager=callbacks
+            callback_group_name, callback_manager=cb
         ) as batch_group:
             for q, c in zip(questions, contexts):
                 human_prompt = CONTEXT_RELEVANCE.format(
