
Commit 0cdbd93

fix: answer_correctness doesn't reset stuff properly (#562)
1 parent 6a88465 commit 0cdbd93

File tree: 4 files changed, +23 −10 lines changed


docs/howtos/integrations/athina.ipynb

Lines changed: 10 additions & 5 deletions
@@ -47,18 +47,23 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "from athina.evals import RagasAnswerCorrectness, RagasAnswerRelevancy, RagasContextRelevancy, RagasFaithfulness \n",
+    "from athina.evals import (\n",
+    "    RagasAnswerCorrectness,\n",
+    "    RagasAnswerRelevancy,\n",
+    "    RagasContextRelevancy,\n",
+    "    RagasFaithfulness,\n",
+    ")\n",
     "from athina.loaders import RagasLoader\n",
     "from athina.keys import AthinaApiKey, OpenAiApiKey\n",
     "from athina.runner.run import EvalRunner\n",
     "import pandas as pd\n",
     "\n",
     "# Set your API keys\n",
-    "OpenAiApiKey.set_key(os.getenv('OPENAI_API_KEY'))\n",
-    "AthinaApiKey.set_key(os.getenv('ATHINA_API_KEY'))\n",
+    "OpenAiApiKey.set_key(os.getenv(\"OPENAI_API_KEY\"))\n",
+    "AthinaApiKey.set_key(os.getenv(\"ATHINA_API_KEY\"))\n",
     "\n",
     "# Load your dataset from a dictionary, json, or csv: https://docs.athina.ai/evals/loading_data\n",
-    "dataset = RagasLoader().load_json('raw_data.json')\n",
+    "dataset = RagasLoader().load_json(\"raw_data.json\")\n",
     "\n",
     "# Configure the eval suite\n",
     "eval_model = \"gpt-3.5-turbo\"
@@ -73,7 +78,7 @@
     "batch_eval_result = EvalRunner.run_suite(\n",
     "    evals=eval_suite,\n",
     "    data=dataset,\n",
-    "    max_parallel_evals=1, # If you increase this, you may run into rate limits\n",
+    "    max_parallel_evals=1,  # If you increase this, you may run into rate limits\n",
     ")\n",
     "\n",
     "pd.DataFrame(batch_eval_result)"

src/ragas/evaluation.py

Lines changed: 11 additions & 2 deletions
@@ -5,18 +5,19 @@
 
 import numpy as np
 from datasets import Dataset, concatenate_datasets
-from langchain_core.language_models import BaseLanguageModel as LangchainLLM
 from langchain_core.embeddings import Embeddings as LangchainEmbeddings
+from langchain_core.language_models import BaseLanguageModel as LangchainLLM
 
 from ragas._analytics import EvaluationEvent, track
 from ragas.callbacks import new_group
 from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
+from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
 from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
+from ragas.metrics._answer_correctness import AnswerCorrectness
 from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM
 from ragas.metrics.critique import AspectCritique
 from ragas.run_config import RunConfig
-from ragas.exceptions import ExceptionInRunner
 
 # from ragas.metrics.critique import AspectCritique
 from ragas.validation import (
@@ -158,6 +159,7 @@ def evaluate(
     binary_metrics = []
     llm_changed: t.List[int] = []
     embeddings_changed: t.List[int] = []
+    answer_correctness_is_set = -1
     for i, metric in enumerate(metrics):
         if isinstance(metric, AspectCritique):
             binary_metrics.append(metric.name)
@@ -169,6 +171,9 @@
             if metric.embeddings is None:
                 metric.embeddings = embeddings
                 embeddings_changed.append(i)
+        if isinstance(metric, AnswerCorrectness):
+            if metric.answer_similarity is None:
+                answer_correctness_is_set = i
 
     # initialize all the models in the metrics
     [m.init(run_config) for m in metrics]
@@ -237,6 +242,10 @@
         t.cast(MetricWithLLM, metrics[i]).llm = None
     for i in embeddings_changed:
         t.cast(MetricWithEmbeddings, metrics[i]).embeddings = None
+    if answer_correctness_is_set != -1:
+        t.cast(
+            AnswerCorrectness, metrics[answer_correctness_is_set]
+        ).answer_similarity = None
 
     # log the evaluation event
     metrics_names = [m.name for m in metrics]
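In short: evaluate() injects llm, embeddings, and (for AnswerCorrectness) an answer_similarity instance into metrics created without them, and this change makes it undo the last of those on the way out, as it already did for llm and embeddings. A minimal sketch of the contract, where dataset_a and dataset_b are hypothetical evaluation datasets:

```python
# Sketch (hypothetical datasets) of the behavior this commit fixes:
# a metric instance reused across evaluate() calls should come back
# "clean", without the answer_similarity that evaluate() injected.
from ragas import evaluate
from ragas.metrics import answer_correctness

assert answer_correctness.answer_similarity is None
result_a = evaluate(dataset_a, metrics=[answer_correctness])

# Before this fix, the similarity model injected during the first run
# leaked out of evaluate(), so a second run silently reused stale state.
assert answer_correctness.answer_similarity is None  # holds after #562
result_b = evaluate(dataset_b, metrics=[answer_correctness])
```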

src/ragas/executor.py

Lines changed: 1 addition & 1 deletion
@@ -2,9 +2,9 @@
 
 import asyncio
 import logging
+import threading
 import typing as t
 from dataclasses import dataclass, field
-import threading
 
 import numpy as np
 from tqdm.auto import tqdm

src/ragas/llms/prompt.py

Lines changed: 1 addition & 2 deletions
@@ -149,7 +149,6 @@ def format(self, **kwargs: t.Any) -> PromptValue:
     def adapt(
         self, language: str, llm: BaseRagasLLM, cache_dir: t.Optional[str] = None
     ) -> Prompt:
-
         def get_all_keys(nested_json):
             keys = set()
             for key, value in nested_json.items():
@@ -160,7 +159,7 @@ def get_all_keys(nested_json):
 
         if self.language == language:
             return self
-
+
         # TODO: Add callbacks
         cache_dir = cache_dir if cache_dir else get_cache_dir()
         if os.path.exists(os.path.join(cache_dir, language, f"{self.name}.json")):
