Skip to content

Commit 7c07569

Browse files
committed
Fix llm_context_recall
1 parent 09e56d0 commit 7c07569

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

rag_experiment_accelerator/evaluation/llm_based_metrics.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,25 +128,26 @@ def llm_context_recall(
128128
double: The context recall score generated between the ground truth (expected) and context.
129129
"""
130130
context = "\n".join(retrieved_contexts)
131-
prompt = (
132-
"\nquestion: "
133-
+ question
134-
+ "\ncontext: "
135-
+ context
136-
+ "\nanswer: "
137-
+ groundtruth_answer
138-
)
139-
result = response_generator.generate_response(
140-
sys_message=llm_context_recall_instruction,
141-
prompt=prompt,
131+
132+
result: list | None = response_generator.generate_response(
133+
llm_context_recall_instruction,
134+
context=context,
135+
question=question,
136+
answer=groundtruth_answer,
142137
)
143-
good_response = '"Attributed": "1"'
144-
bad_response = '"Attributed": "0"'
145138

146-
return (
147-
result.count(good_response)
148-
/ (result.count(good_response) + result.count(bad_response))
149-
) * 100
139+
good_responses = 0
140+
141+
for response in result:
142+
try:
143+
score = response.get("attributed", 0)
144+
good_responses += int(score)
145+
except ValueError:
146+
logger.warning(f"Unable to parse {score} as int.")
147+
if not result:
148+
return -1
149+
else:
150+
return (good_responses / len(result)) * 100
150151

151152

152153
def compute_llm_based_score(

rag_experiment_accelerator/llm/prompt/ragas_prompts.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@ def validate_context_recall(text: str) -> bool:
1616

1717
def is_valid_entry(entry):
1818
statement_key_pattern = re.compile(r"^statement_\d+$")
19-
for key in entry.keys():
20-
if key not in ["reason", "attributed"] or not statement_key_pattern.match(
21-
key
22-
):
23-
return False
19+
return all(
20+
key in ["reason", "attributed"] or statement_key_pattern.match(key)
21+
for key in entry.keys()
22+
)
2423

2524
return isinstance(json_text, list) and all(
2625
is_valid_entry(entry) for entry in json_text

0 commit comments

Comments
 (0)