Commit 5d70348

Prompt class + implementation in Context Precision
1 parent 1ffc6d5 commit 5d70348

4 files changed
+240 -126 lines changed

src/ragas/metrics/collections/_context_precision.py

Lines changed: 15 additions & 16 deletions
@@ -4,26 +4,19 @@
 from typing import List
 
 import numpy as np
-from pydantic import BaseModel
 
 from ragas.metrics.collections.base import BaseMetric
 from ragas.metrics.result import MetricResult
 from ragas.prompt.metrics.context_precision import (
-    context_precision_with_reference_prompt,
-    context_precision_without_reference_prompt,
+    ContextPrecisionInput,
+    ContextPrecisionOutput,
+    ContextPrecisionPrompt,
 )
 
 if t.TYPE_CHECKING:
     from ragas.llms.base import InstructorBaseRagasLLM
 
 
-class ContextPrecisionOutput(BaseModel):
-    """Structured output for context precision evaluation."""
-
-    reason: str
-    verdict: int
-
-
 class ContextPrecisionWithReference(BaseMetric):
     """
     Modern v2 implementation of context precision with reference.

@@ -79,6 +72,7 @@ def __init__(
         """
         # Set attributes explicitly before calling super()
         self.llm = llm
+        self.prompt = ContextPrecisionPrompt()  # Initialize prompt class once
 
         # Call super() for validation (without passing llm in kwargs)
         super().__init__(name=name, **kwargs)

@@ -108,10 +102,12 @@ async def ascore(
         # Evaluate each retrieved context
         verdicts = []
         for context in retrieved_contexts:
-            prompt = context_precision_with_reference_prompt(
-                user_input, context, reference
+            # Create input data and generate prompt
+            input_data = ContextPrecisionInput(
+                question=user_input, context=context, answer=reference
             )
-            result = await self.llm.agenerate(prompt, ContextPrecisionOutput)
+            prompt_string = self.prompt.to_string(input_data)
+            result = await self.llm.agenerate(prompt_string, ContextPrecisionOutput)
             verdicts.append(result.verdict)
 
         # Calculate average precision

@@ -196,6 +192,7 @@ def __init__(
         """
         # Set attributes explicitly before calling super()
         self.llm = llm
+        self.prompt = ContextPrecisionPrompt()  # Initialize prompt class once
 
         # Call super() for validation (without passing llm in kwargs)
         super().__init__(name=name, **kwargs)

@@ -225,10 +222,12 @@ async def ascore(
         # Evaluate each retrieved context
         verdicts = []
        for context in retrieved_contexts:
-            prompt = context_precision_without_reference_prompt(
-                user_input, context, response
+            # Create input data and generate prompt
+            input_data = ContextPrecisionInput(
+                question=user_input, context=context, answer=response
             )
-            result = await self.llm.agenerate(prompt, ContextPrecisionOutput)
+            prompt_string = self.prompt.to_string(input_data)
+            result = await self.llm.agenerate(prompt_string, ContextPrecisionOutput)
             verdicts.append(result.verdict)
 
         # Calculate average precision
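For orientation, a minimal sketch of the new call path, assuming an InstructorBaseRagasLLM-style client named llm with the async agenerate(prompt, response_model) method used above; the sample strings and the helper function are illustrative, not part of the commit:

from ragas.prompt.metrics.context_precision import (
    ContextPrecisionInput,
    ContextPrecisionOutput,
    ContextPrecisionPrompt,
)

async def score_one_context(llm) -> int:
    # Build the typed input, render it to a prompt string, and request structured output
    prompt = ContextPrecisionPrompt()
    input_data = ContextPrecisionInput(
        question="When was the Eiffel Tower built?",  # illustrative sample data
        context="The Eiffel Tower was completed in 1889.",
        answer="It was completed in 1889.",
    )
    prompt_string = prompt.to_string(input_data)
    result = await llm.agenerate(prompt_string, ContextPrecisionOutput)
    return result.verdict  # per-context verdict; the metric averages these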

src/ragas/prompt/metrics/__init__.py

Lines changed: 6 additions & 6 deletions
@@ -4,16 +4,16 @@
 from ragas.prompt.metrics.answer_relevance import answer_relevancy_prompt
 from ragas.prompt.metrics.common import nli_statement_prompt, statement_generator_prompt
 from ragas.prompt.metrics.context_precision import (
-    context_precision_prompt,
-    context_precision_with_reference_prompt,
-    context_precision_without_reference_prompt,
+    ContextPrecisionInput,
+    ContextPrecisionOutput,
+    ContextPrecisionPrompt,
 )
 
 __all__ = [
     "answer_relevancy_prompt",
-    "context_precision_prompt",
-    "context_precision_with_reference_prompt",
-    "context_precision_without_reference_prompt",
+    "ContextPrecisionPrompt",
+    "ContextPrecisionInput",
+    "ContextPrecisionOutput",
     "correctness_classifier_prompt",
     "nli_statement_prompt",
     "statement_generator_prompt",
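With this change, downstream code imports the class-based names from the package rather than the removed prompt functions; a minimal example of the updated import:

from ragas.prompt.metrics import (
    ContextPrecisionInput,
    ContextPrecisionOutput,
    ContextPrecisionPrompt,
)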
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+"""Base prompt class for metrics with structured input/output models."""
+
+import json
+import typing as t
+from abc import ABC
+
+from pydantic import BaseModel
+
+# Type variables for generics
+InputModel = t.TypeVar("InputModel", bound=BaseModel)
+OutputModel = t.TypeVar("OutputModel", bound=BaseModel)
+
+
+class BasePrompt(ABC, t.Generic[InputModel, OutputModel]):
+    """
+    Base class for structured prompts with type-safe input/output models.
+
+    Attributes:
+        input_model: Pydantic model class for input validation
+        output_model: Pydantic model class for output schema generation
+        instruction: Task description for the LLM
+        examples: List of (input, output) example pairs for few-shot learning
+        language: Language for the prompt (default: "english")
+    """
+
+    # Must be set by subclasses
+    input_model: t.Type[InputModel]
+    output_model: t.Type[OutputModel]
+    instruction: str
+    examples: t.List[t.Tuple[InputModel, OutputModel]]
+    language: str = "english"
+
+    def to_string(self, data: InputModel) -> str:
+        """
+        Convert prompt with input data to complete prompt string for LLM.
+
+        Args:
+            data: Input data instance (validated by input_model)
+
+        Returns:
+            Complete prompt string ready for LLM
+        """
+        # Generate JSON schema for output
+        output_schema = json.dumps(self.output_model.model_json_schema())
+
+        # Generate examples section
+        examples_str = self._generate_examples()
+
+        # Convert input data to JSON
+        input_json = data.model_dump_json(indent=4, exclude_none=True)
+
+        # Build complete prompt (matches existing function format)
+        return f"""{self.instruction}
+Please return the output in a JSON format that complies with the following schema as specified in JSON Schema:
+{output_schema}Do not use single quotes in your response but double quotes,properly escaped with a backslash.
+
+{examples_str}
+-----------------------------
+
+Now perform the same with the following input
+input: {input_json}
+Output: """
+
+    def _generate_examples(self) -> str:
+        """
+        Generate examples section of the prompt.
+
+        Returns:
+            Formatted examples string or empty string if no examples
+        """
+        if not self.examples:
+            return ""
+
+        example_strings = []
+        for idx, (input_data, output_data) in enumerate(self.examples):
+            example_strings.append(
+                f"Example {idx + 1}\n"
+                f"Input: {input_data.model_dump_json(indent=4)}\n"
+                f"Output: {output_data.model_dump_json(indent=4)}"
+            )
+
+        return "--------EXAMPLES-----------\n" + "\n\n".join(example_strings)
+
+    async def adapt(
+        self,
+        target_language: str,
+        llm,
+        adapt_instruction: bool = False,
+    ) -> "BasePrompt[InputModel, OutputModel]":
+        """
+        Adapt the prompt to a new language using minimal translation.
+
+        Args:
+            target_language: Target language (e.g., "spanish", "french")
+            llm: LLM instance for translation
+            adapt_instruction: Whether to adapt instruction text (default: False)
+
+        Returns:
+            New prompt instance adapted to the target language
+        """
+        import copy
+
+        # Create adapted prompt
+        new_prompt = copy.deepcopy(self)
+        new_prompt.language = target_language
+
+        # Translate instruction if requested
+        if adapt_instruction:
+            instruction_prompt = f"Translate this to {target_language}, keep technical terms: {self.instruction}"
+            try:
+                response = await llm.agenerate(instruction_prompt)
+                new_prompt.instruction = str(response).strip()
+            except Exception:
+                # Keep original if translation fails
+                pass
+
+        # Translate examples (simplified approach)
+        translated_examples = []
+        for input_ex, output_ex in self.examples:
+            try:
+                # Simple per-example translation
+                example_prompt = f"""Translate this example to {target_language}, keep the same structure:
+
+Input: {input_ex.model_dump_json()}
+Output: {output_ex.model_dump_json()}
+
+Return as: Input: {{translated_input_json}} Output: {{translated_output_json}}"""
+
+                response = await llm.agenerate(example_prompt)
+
+                # Try to extract translated JSON (basic parsing)
+                response_str = str(response)
+                if "Input:" in response_str and "Output:" in response_str:
+                    parts = response_str.split("Output:")
+                    input_part = parts[0].replace("Input:", "").strip()
+                    output_part = parts[1].strip()
+
+                    translated_input = self.input_model.model_validate_json(input_part)
+                    translated_output = self.output_model.model_validate_json(
+                        output_part
+                    )
+                    translated_examples.append((translated_input, translated_output))
+                else:
+                    # Fallback to original
+                    translated_examples.append((input_ex, output_ex))
+
+            except Exception:
+                # Fallback to original example if translation fails
+                translated_examples.append((input_ex, output_ex))
+
+        new_prompt.examples = translated_examples
+        return new_prompt
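The concrete ContextPrecisionPrompt used by the metric lives in ragas/prompt/metrics/context_precision.py, the fourth changed file, which is not reproduced in this view. A sketch of how such a subclass of BasePrompt could look, using only the field names visible in this commit (question/context/answer in, reason/verdict out) and an illustrative instruction string:

# Hypothetical sketch; not the actual contents of the fourth changed file.
from pydantic import BaseModel


class ContextPrecisionInput(BaseModel):
    question: str
    context: str
    answer: str


class ContextPrecisionOutput(BaseModel):
    """Structured output for context precision evaluation."""

    reason: str
    verdict: int


class ContextPrecisionPrompt(BasePrompt[ContextPrecisionInput, ContextPrecisionOutput]):
    input_model = ContextPrecisionInput
    output_model = ContextPrecisionOutput
    instruction = (  # illustrative wording only
        "Given a question, context, and answer, verify whether the context was "
        "useful in arriving at the given answer. Give a verdict of 1 if useful "
        "and 0 if not, with a brief reason."
    )
    examples = []  # few-shot (input, output) pairs omitted in this sketch

With a subclass like this in place, the inherited adapt() method can produce translated variants, for example: adapted = await ContextPrecisionPrompt().adapt("spanish", llm).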
