Skip to content

Commit 07655ff

Browse files
authored
refactor: docs and warnings for metric base new structure (#2333)
Follow up on #2320
1 parent 31bf2a8 commit 07655ff

File tree

7 files changed

+83
-43
lines changed

7 files changed

+83
-43
lines changed

src/ragas/metrics/base.py

Lines changed: 61 additions & 28 deletions
Original file line number · Diff line number · Diff line change
@@ -781,11 +781,29 @@ def create_auto_response_model(name: str, **fields):
781781
Name for the model class
782782
**fields
783783
Field definitions in create_model format
784+
Each field is specified as: field_name=(type, default_or_field_info)
784785
785786
Returns:
786787
--------
787788
Type[BaseModel]
788789
Pydantic model class marked as auto-generated
790+
791+
Examples:
792+
---------
793+
>>> from pydantic import Field
794+
>>> # Simple model with required fields
795+
>>> ResponseModel = create_auto_response_model(
796+
... "ResponseModel",
797+
... value=(str, ...),
798+
... reason=(str, ...)
799+
... )
800+
>>>
801+
>>> # Model with Field validators and descriptions
802+
>>> ResponseModel = create_auto_response_model(
803+
... "ResponseModel",
804+
... value=(str, Field(..., description="The predicted value")),
805+
... reason=(str, Field(..., description="Reasoning for the prediction"))
806+
... )
789807
"""
790808
from pydantic import create_model
791809

@@ -927,23 +945,33 @@ def save(self, path: t.Optional[str] = None) -> None:
927945
elif not file_path.suffix:
928946
file_path = file_path.with_suffix(".json")
929947

948+
# Collect warning messages for data loss
949+
warning_messages = []
950+
930951
if hasattr(self, "_response_model") and self._response_model:
931952
# Only warn for custom response models, not auto-generated ones
932953
if not getattr(self._response_model, "__ragas_auto_generated__", False):
933-
warnings.warn(
934-
"Custom response_model cannot be saved and will be lost. "
935-
"You'll need to set it manually after loading."
954+
warning_messages.append(
955+
"- Custom response_model will be lost (set it manually after loading)"
936956
)
937957

938-
# Serialize the prompt
939-
prompt_data = self._serialize_prompt()
958+
# Serialize the prompt (may add embedding_model warning)
959+
prompt_data = self._serialize_prompt(warning_messages)
940960

941961
# Determine the metric type
942962
metric_type = self.__class__.__name__
943963

944964
# Get metric-specific config
945965
config = self._get_metric_config()
946966

967+
# Emit consolidated warning if there's data loss
968+
if warning_messages:
969+
warnings.warn(
970+
"Some metric components cannot be saved and will be lost:\n"
971+
+ "\n".join(warning_messages)
972+
+ "\n\nYou'll need to provide these when loading the metric."
973+
)
974+
947975
data = {
948976
"format_version": "1.0",
949977
"metric_type": metric_type,
@@ -962,22 +990,17 @@ def save(self, path: t.Optional[str] = None) -> None:
962990
except (OSError, IOError) as e:
963991
raise ValueError(f"Cannot save metric to {file_path}: {e}")
964992

965-
def _serialize_prompt(self) -> t.Dict[str, t.Any]:
993+
def _serialize_prompt(self, warning_messages: t.List[str]) -> t.Dict[str, t.Any]:
966994
"""Serialize the prompt for storage."""
967995
from ragas.prompt.dynamic_few_shot import DynamicFewShotPrompt
968996
from ragas.prompt.simple_prompt import Prompt
969997

970998
if isinstance(self.prompt, str):
971999
return {"type": "string", "instruction": self.prompt}
9721000
elif isinstance(self.prompt, DynamicFewShotPrompt):
973-
# Warn about embedding model
9741001
if self.prompt.example_store.embedding_model:
975-
import warnings
976-
977-
warnings.warn(
978-
"embedding_model cannot be saved and will be lost. "
979-
"You'll need to provide it when loading using: "
980-
"load(path, embedding_model=YourModel)"
1002+
warning_messages.append(
1003+
"- embedding_model will be lost (provide it when loading: load(path, embedding_model=YourModel))"
9811004
)
9821005

9831006
return {
@@ -1171,13 +1194,26 @@ def _deserialize_prompt(
11711194
prompt_type = prompt_data.get("type")
11721195

11731196
if prompt_type == "string":
1197+
if "instruction" not in prompt_data:
1198+
raise ValueError(
1199+
"Prompt data missing required 'instruction' field for string prompt"
1200+
)
11741201
return prompt_data["instruction"]
11751202
elif prompt_type == "Prompt":
1203+
if "instruction" not in prompt_data:
1204+
raise ValueError(
1205+
"Prompt data missing required 'instruction' field for Prompt"
1206+
)
11761207
examples = [
11771208
(ex["input"], ex["output"]) for ex in prompt_data.get("examples", [])
11781209
]
11791210
return Prompt(instruction=prompt_data["instruction"], examples=examples)
11801211
elif prompt_type == "DynamicFewShotPrompt":
1212+
if "instruction" not in prompt_data:
1213+
raise ValueError(
1214+
"Prompt data missing required 'instruction' field for DynamicFewShotPrompt"
1215+
)
1216+
11811217
if not embedding_model:
11821218
import warnings
11831219

@@ -1380,35 +1416,32 @@ def __repr__(self) -> str:
13801416
"""Return a clean string representation of the metric."""
13811417
metric_type = self.__class__.__name__
13821418

1383-
# Get allowed values in a clean format
13841419
allowed_values = self.allowed_values
1385-
if isinstance(allowed_values, list):
1386-
allowed_values_str = f", allowed_values={allowed_values}"
1387-
elif isinstance(allowed_values, tuple):
1388-
allowed_values_str = f", allowed_values={allowed_values}"
1389-
elif isinstance(allowed_values, range):
1420+
if isinstance(allowed_values, range):
13901421
allowed_values_str = (
13911422
f", allowed_values=({allowed_values.start}, {allowed_values.stop})"
13921423
)
1393-
else:
1424+
elif isinstance(allowed_values, (list, tuple, int)):
13941425
allowed_values_str = f", allowed_values={allowed_values}"
1426+
else:
1427+
allowed_values_str = f", allowed_values={repr(allowed_values)}"
13951428

1396-
# Get prompt string (truncated)
13971429
prompt_str = ""
13981430
if self.prompt:
1399-
if isinstance(self.prompt, str):
1400-
instruction = self.prompt
1401-
else:
1402-
instruction = (
1431+
instruction = (
1432+
self.prompt
1433+
if isinstance(self.prompt, str)
1434+
else (
14031435
self.prompt.instruction
14041436
if hasattr(self.prompt, "instruction")
14051437
else str(self.prompt)
14061438
)
1439+
)
14071440

14081441
if instruction:
1409-
# Truncate long prompts
1410-
if len(instruction) > 80:
1411-
prompt_str = f", prompt='{instruction[:77]}...'"
1442+
max_len = 80
1443+
if len(instruction) > max_len:
1444+
prompt_str = f", prompt='{instruction[: max_len - 3]}...'"
14121445
else:
14131446
prompt_str = f", prompt='{instruction}'"
14141447

src/ragas/metrics/decorator.py

Lines changed: 4 additions & 11 deletions
Original file line number · Diff line number · Diff line change
@@ -319,22 +319,15 @@ def __call__(self, *args, **kwargs):
319319
return self._func(*args, **kwargs)
320320

321321
def __repr__(self) -> str:
322-
"""Return a clean string representation of the decorator-based metric."""
323-
# Get function signature parameters
322+
from ragas.metrics.validators import get_metric_type_name
323+
324324
param_names = list(sig.parameters.keys())
325325
param_str = ", ".join(param_names)
326326

327-
# Get metric type based on allowed_values
328327
metric_type = "CustomMetric"
329328
if hasattr(self, "allowed_values"):
330-
if isinstance(self.allowed_values, list):
331-
metric_type = "DiscreteMetric"
332-
elif isinstance(self.allowed_values, tuple):
333-
metric_type = "NumericMetric"
334-
elif isinstance(self.allowed_values, int):
335-
metric_type = "RankingMetric"
336-
337-
# Get allowed values string
329+
metric_type = get_metric_type_name(self.allowed_values)
330+
338331
allowed_values_str = ""
339332
if hasattr(self, "allowed_values"):
340333
allowed_values_str = f"[{self.allowed_values!r}]"

src/ragas/metrics/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -105,4 +105,4 @@ def discrete_metric(
105105
allowed_values = ["pass", "fail"]
106106

107107
decorator_factory = create_metric_decorator()
108-
return decorator_factory(name=name, allowed_values=allowed_values, **metric_params)
108+
return decorator_factory(name=name, allowed_values=allowed_values, **metric_params) # type: ignore[return-value]

src/ragas/metrics/numeric.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -119,4 +119,4 @@ def numeric_metric(
119119
allowed_values = (0.0, 1.0)
120120

121121
decorator_factory = create_metric_decorator()
122-
return decorator_factory(name=name, allowed_values=allowed_values, **metric_params)
122+
return decorator_factory(name=name, allowed_values=allowed_values, **metric_params) # type: ignore[return-value]

src/ragas/metrics/ranking.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -110,4 +110,4 @@ def ranking_metric(
110110
allowed_values = 2
111111

112112
decorator_factory = create_metric_decorator()
113-
return decorator_factory(name=name, allowed_values=allowed_values, **metric_params)
113+
return decorator_factory(name=name, allowed_values=allowed_values, **metric_params) # type: ignore[return-value]

src/ragas/metrics/validators.py

Lines changed: 14 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -5,6 +5,8 @@
55
"NumericValidator",
66
"RankingValidator",
77
"AllowedValuesType",
8+
"get_validator_for_allowed_values",
9+
"get_metric_type_name",
810
]
911

1012
import typing as t
@@ -109,3 +111,15 @@ def get_validator_for_allowed_values(
109111
else:
110112
# Default to discrete if unclear
111113
return DiscreteValidator
114+
115+
116+
def get_metric_type_name(allowed_values: AllowedValuesType) -> str:
117+
"""Get the metric type name based on allowed_values type."""
118+
if isinstance(allowed_values, list):
119+
return "DiscreteMetric"
120+
elif isinstance(allowed_values, (tuple, range)):
121+
return "NumericMetric"
122+
elif isinstance(allowed_values, int):
123+
return "RankingMetric"
124+
else:
125+
return "CustomMetric"

tests/unit/test_simple_llm_metric_persistence.py

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -206,7 +206,7 @@ async def aembed_query(self, text: str):
206206

207207
try:
208208
# Save (should warn about embedding model)
209-
with pytest.warns(UserWarning, match="embedding_model cannot be saved"):
209+
with pytest.warns(UserWarning, match="embedding_model will be lost"):
210210
original_metric.save(temp_path)
211211

212212
# Load (provide embedding model)

0 commit comments

Comments (0)