Skip to content

Commit e1e05f8

Browse files
authored
fix: edge case in CR (#728)
fixes: #721
1 parent 5b59814 commit e1e05f8

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/ragas/metrics/_context_recall.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def _compute_score(self, response: t.Any) -> float:
108108
]
109109
denom = len(response)
110110
numerator = sum(response)
111-
score = numerator / denom
111+
score = numerator / denom if denom > 0 else np.nan
112112

113113
if np.isnan(score):
114114
logger.warning(

src/ragas/testset/filters.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,21 @@
2222
logger = logging.getLogger(__name__)
2323

2424

25+
def patch_save_json(data, filename):
26+
import json
27+
import os
28+
29+
path = f"/Users/shahules/Myprojects/ragas/experiments/{filename}.json"
30+
if os.path.exists(path):
31+
database = json.load(open(path))
32+
database = database if isinstance(database, list) else [database]
33+
database.append(data)
34+
else:
35+
database = [data]
36+
with open(path, "w") as f:
37+
json.dump(database, f, indent=4)
38+
39+
2540
@dataclass
2641
class Filter(ABC):
2742
llm: BaseRagasLLM
@@ -56,6 +71,8 @@ async def filter(self, node: Node) -> t.Dict:
5671
score = await json_loader.safe_load(output, llm=self.llm)
5772
score = score if isinstance(score, dict) else {}
5873
logger.debug("node filter: %s", score)
74+
score.update({"context": node.page_content})
75+
patch_save_json(score, "node_filter")
5976
score.update({"score": score.get("score", 0) >= self.threshold})
6077
return score
6178

@@ -87,6 +104,8 @@ async def filter(self, question: str) -> t.Tuple[bool, str]:
87104
results = results.generations[0][0].text.strip()
88105
json_results = await json_loader.safe_load(results, llm=self.llm)
89106
json_results = json_results if isinstance(json_results, dict) else {}
107+
json_results.update({"question": question})
108+
patch_save_json(json_results, "question_filter")
90109
logger.debug("filtered question: %s", json_results)
91110
return json_results.get("verdict") == "1", json_results.get("feedback", "")
92111

@@ -120,6 +139,8 @@ async def filter(self, simple_question: str, compressed_question: str) -> bool:
120139
results = results.generations[0][0].text.strip()
121140
json_results = await json_loader.safe_load(results, llm=self.llm)
122141
json_results = json_results if isinstance(json_results, dict) else {}
142+
json_results.update({"questions": [simple_question, compressed_question]})
143+
patch_save_json(json_results, "evolution_filter")
123144
logger.debug("evolution filter: %s", json_results)
124145
return json_results.get("verdict") == "1"
125146

0 commit comments

Comments
 (0)