Skip to content

Commit 41e9e54

Browse files
authored
fix: handle edge cases in prompt processing (#374)
1 parent b455475 commit 41e9e54

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

src/ragas/metrics/_answer_correctness.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,21 @@ def _score_batch(
119119
f1_score = []
120120
for prediction in outputs:
121121
prediction = json_loader.safe_load(prediction[0].text, self.llm)
122-
prediction = [
123-
item.get(key_map[k], np.nan)
124-
for item in prediction
125-
for k in key_map.keys()
126-
]
127-
tp, fp, fn = [
128-
len(item) if isinstance(item, list) else np.nan for item in prediction
129-
]
130-
score = tp / (tp + 0.5 * (fp + fn))
122+
prediction = prediction if isinstance(prediction, list) else []
123+
if prediction:
124+
prediction = [
125+
item.get(key_map[k], np.nan)
126+
for item in prediction
127+
for k in key_map.keys()
128+
]
129+
tp, fp, fn = [
130+
len(item) if isinstance(item, list) else np.nan
131+
for item in prediction
132+
]
133+
score = tp / (tp + 0.5 * (fp + fn))
134+
else:
135+
score = np.nan
136+
131137
f1_score.append(score)
132138

133139
similarity_scores = self.answer_similarity._score_batch(dataset) # type: ignore

src/ragas/metrics/_faithfulness.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def _score_batch(
173173
scores = []
174174
for output in outputs:
175175
output = json_loader.safe_load(output[0].text, self.llm)
176-
output = output if output else []
176+
output = output if isinstance(output, list) else []
177177
faithful_statements = sum(
178178
verdict_score_map.get(dict.get("verdict", "").lower(), np.nan)
179179
for dict in output

src/ragas/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _fix_to_json(
111111
callbacks: t.Optional[CallbackManager] = None,
112112
callback_group_name: str = "batch",
113113
):
114-
# TODO (executor)
114+
# TODO (executor)
115115
with trace_as_chain_group(
116116
callback_group_name, callback_manager=callbacks
117117
) as batch_group:

0 commit comments

Comments
 (0)