Skip to content

Commit 5476079

Browse files
authored
fix: handle non dict cases (#564)
fixes: #555
1 parent a2a177d commit 5476079

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

src/ragas/metrics/_answer_correctness.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def _compute_statement_presence(self, prediction: t.Any) -> float:
119119
}
120120

121121
prediction = prediction if isinstance(prediction, list) else [prediction]
122+
prediction = [item if isinstance(item, dict) else {} for item in prediction]
122123
if prediction:
123124
prediction = [
124125
item.get(key_map[k], np.nan)

src/ragas/metrics/_context_precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _calculate_average_precision(self, json_responses: t.List[t.Dict]) -> float:
9696
item if isinstance(item, dict) else {} for item in json_responses
9797
]
9898
verdict_list = [
99-
int("1" == resp.get("verdict", "0").strip())
99+
int("1" == resp.get("verdict", "").strip())
100100
if resp.get("verdict")
101101
else np.nan
102102
for resp in json_responses

src/ragas/metrics/_context_recall.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,24 @@ def _create_context_recall_prompt(self, row: t.Dict) -> PromptValue:
9090
return self.context_recall_prompt.format(question=qstn, context=ctx, answer=gt)
9191

9292
def _compute_score(self, response: t.Any) -> float:
93-
if response:
94-
response = [
95-
int(item.get("Attributed", "0").strip() == "1")
96-
if item.get("Attributed")
97-
else np.nan
98-
for item in response
99-
]
100-
denom = len(response)
101-
numerator = sum(response)
102-
return numerator / denom
103-
else:
93+
response = response if isinstance(response, list) else [response]
94+
response = [item if isinstance(item, dict) else {} for item in response]
95+
response = [
96+
int(item.get("Attributed").strip() == "1")
97+
if item.get("Attributed")
98+
else np.nan
99+
for item in response
100+
]
101+
denom = len(response)
102+
numerator = sum(response)
103+
score = numerator / denom
104+
105+
if np.isnan(score):
104106
logger.warning(
105107
"Invalid JSON response. Expected dictionary with key 'Attributed'"
106108
)
107-
return np.nan
109+
110+
return score
108111

109112
async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
110113
assert self.llm is not None, "set LLM before use"

0 commit comments

Comments
 (0)