
Commit 80e77b6

jjmachan and anistark authored
fix: metric inheritance patterns: separate factory-created metrics from class-instantiated metrics (#2316)
## Summary

**Primary Motivation**: This PR fixes a fundamental inheritance pattern issue in the metrics system where factory-created metrics (via `@discrete_metric`, `@numeric_metric`, etc.) and class-instantiated metrics (via `DiscreteMetric()`, `NumericMetric()`, etc.) should have different base classes but were incorrectly sharing the same inheritance hierarchy.

**The Problem**:
- Factory-created metrics should inherit from `SimpleBaseMetric` (lightweight, decorator-based)
- Class-instantiated metrics should inherit from `SimpleLLMMetric` (LLM-enabled, full-featured)
- Previously, both paths incorrectly inherited from the same base classes, creating confusion and incorrect behavior

**The Solution**:
- **Separated base classes**: Created `SimpleBaseMetric` (for factory) and `SimpleLLMMetric` (for class instantiation) as distinct base classes
- **Removed `llm_based.py`**: Consolidated `BaseLLMMetric` and `LLMMetric` into `base.py` as `SimpleBaseMetric` and `SimpleLLMMetric`
- **Fixed decorator inheritance**: Factory methods now create metrics that inherit from `SimpleBaseMetric + ValidatorMixin` only
- **Fixed class inheritance**: Class-based metrics like `DiscreteMetric` now inherit from `SimpleLLMMetric + ValidatorMixin`
- **Added validator system**: Introduced modular validation mixins that work with both inheritance patterns
- **Maintained backward compatibility**: Added aliases `BaseMetric = SimpleBaseMetric` and `LLMMetric = SimpleLLMMetric`

**Exact Steps Taken**:
1. `7d6de2a` - Updated gitignore for experimental directories
2. `c6101f8` - Renamed classes and established proper naming convention
3. `46450d8` - Refactored decorator and class-based inheritance patterns
4. `a464c37` - Simplified validator system with proper mixins
5. `fe996f6` - Removed `llm_based.py` after consolidation

## Test plan

- [ ] Verify factory-created metrics (`@discrete_metric`) inherit from `SimpleBaseMetric` only
- [ ] Verify class-instantiated metrics (`DiscreteMetric()`) inherit from `SimpleLLMMetric`
- [ ] Test that both patterns work correctly with their respective validation mixins
- [ ] Ensure backward compatibility with existing metric imports
- [ ] Validate all metric functionality (scoring, async operations, alignment)
- [ ] Run full test suite to ensure no regressions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

---------

Co-authored-by: Ani <[email protected]>
1 parent 1dea897 commit 80e77b6
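For orientation, here is a minimal sketch of the two usage patterns the summary distinguishes. It is illustrative only: the decorator and constructor arguments (`allowed_values`, the prompt text) are assumptions rather than taken from this PR; the grounded parts are the inheritance split and the `BaseMetric`/`LLMMetric` aliases introduced in the diff below.

```python
# Illustrative sketch of the two inheritance patterns (argument names are assumed).
from ragas.metrics import BaseMetric, LLMMetric  # aliases for SimpleBaseMetric / SimpleLLMMetric
from ragas.metrics import DiscreteMetric, discrete_metric


# Factory path: the decorator wraps a plain function; the resulting metric
# should be a lightweight SimpleBaseMetric with no LLM machinery attached.
@discrete_metric(name="exact_match", allowed_values=["pass", "fail"])
def exact_match(prediction: str, reference: str):
    return "pass" if prediction == reference else "fail"


# Class path: direct instantiation yields an LLM-enabled SimpleLLMMetric.
response_quality = DiscreteMetric(
    name="response_quality",
    prompt="Rate the response as pass or fail.\nResponse: {response}",
    allowed_values=["pass", "fail"],
)

assert isinstance(exact_match, BaseMetric) and not isinstance(exact_match, LLMMetric)
assert isinstance(response_quality, LLMMetric)  # SimpleLLMMetric extends SimpleBaseMetric
```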

11 files changed: +468 −361 lines changed

.claude/commands/create-pr.md

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+<!--
+WHITELIST_AFTER_APPROVAL:
+  - Bash(git push:*)
+  - Bash(gh pr create:*)
+-->
+
+Please create a pull request with the changes on the current branch. $ARGUMENTS
+
+**IMPORTANT: Use planning mode for this command.**
+
+Follow these steps:
+
+## Planning Phase (Research Only):
+1. Check git status to ensure working tree is clean
+2. Verify current branch is not main/master
+3. Check if current branch tracks a remote branch and is up to date
+4. Run `git log --oneline main..HEAD` to see commits that will be included in PR
+5. Run `git diff main...HEAD` to understand the full scope of changes
+6. Analyze all changes and commits to create a comprehensive PR summary
+7. Draft PR title and description with:
+   - Title: Clear, descriptive title based on the changes
+   - Body: Include "## Summary" with bullet points, "## Test plan" checklist, and Claude signature
+8. **Present the draft PR title and description using ExitPlanMode tool for user approval**
+
+## Execution Phase (After User Approval):
+9. Push branch to remote with upstream tracking if needed
+10. Create PR using `gh pr create` with the approved title and description
+11. Return the PR URL for easy access
+
+Use the GitHub CLI (`gh`) for all GitHub-related operations. If $ARGUMENTS are provided, incorporate them as hints for the PR title or description.

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -196,6 +196,9 @@ examples/build/
 examples/*.egg-info/
 examples/ragas_examples/_version.py
 examples/ragas_examples/text2sql/experiments/*
+examples/ragas_examples/benchmark_llm/experiments/*
+BookSQL-files
+text2sql_logs
 
 # MLflow artifacts
 mlartifacts

src/ragas/metrics/__init__.py

Lines changed: 3 additions & 6 deletions
@@ -75,10 +75,11 @@
     MetricWithEmbeddings,
     MetricWithLLM,
     MultiTurnMetric,
+    SimpleBaseMetric as BaseMetric,
+    SimpleLLMMetric as LLMMetric,
     SingleTurnMetric,
 )
 from ragas.metrics.discrete import DiscreteMetric, discrete_metric
-from ragas.metrics.llm_based import BaseLLMMetric, LLMMetric
 from ragas.metrics.numeric import NumericMetric, numeric_metric
 from ragas.metrics.ranking import RankingMetric, ranking_metric
 from ragas.metrics.result import MetricResult
@@ -93,7 +94,7 @@
     "MultiTurnMetric",
     "MetricOutputType",
     # LLM-based metrics (moved from experimental)
-    "BaseLLMMetric",
+    "BaseMetric",
     "LLMMetric",
     "MetricResult",
     "DiscreteMetric",
@@ -160,7 +161,3 @@
     "MultiModalRelevance",
     "multimodal_relevance",
 ]
-
-# Backward compatibility aliases for experimental imports
-# These maintain compatibility while experimental code is migrated
-BaseMetric = BaseLLMMetric  # experimental BaseMetric -> BaseLLMMetric
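The practical effect of this hunk for downstream imports, sketched as a quick sanity check (assuming only what the aliases above state):

```python
# The old public names keep resolving; they now point at the consolidated classes in base.py.
from ragas.metrics import BaseMetric, LLMMetric
from ragas.metrics.base import SimpleBaseMetric, SimpleLLMMetric

assert BaseMetric is SimpleBaseMetric
assert LLMMetric is SimpleLLMMetric
```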

src/ragas/metrics/base.py

Lines changed: 298 additions & 0 deletions
@@ -22,10 +22,14 @@
 
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
+    from pydantic import BaseModel
 
     from ragas.config import DemonstrationConfig, InstructionConfig
+    from ragas.dataset import Dataset
     from ragas.embeddings import BaseRagasEmbedding, BaseRagasEmbeddings
     from ragas.llms import BaseRagasLLM
+    from ragas.metrics.result import MetricResult
+    from ragas.prompt.simple_prompt import Prompt
 
 logger = logging.getLogger(__name__)
 
@@ -723,3 +727,297 @@ class ModeMetric(t.Protocol):
 
 
 ensembler = Ensember()
+
+
+@dataclass
+class SimpleBaseMetric(ABC):
+    """Base class for simple metrics that return MetricResult objects."""
+
+    name: str
+
+    @abstractmethod
+    def score(self, **kwargs) -> "MetricResult":
+        pass
+
+    @abstractmethod
+    async def ascore(self, **kwargs) -> "MetricResult":
+        pass
+
+    def batch_score(
+        self,
+        inputs: t.List[t.Dict[str, t.Any]],
+    ) -> t.List["MetricResult"]:
+        return [self.score(**input_dict) for input_dict in inputs]
+
+    async def abatch_score(
+        self,
+        inputs: t.List[t.Dict[str, t.Any]],
+    ) -> t.List["MetricResult"]:
+        async_tasks = []
+        for input_dict in inputs:
+            # Process input asynchronously
+            async_tasks.append(self.ascore(**input_dict))
+
+        # Run all tasks concurrently and return results
+        return await asyncio.gather(*async_tasks)
+
+
+@dataclass
+class SimpleLLMMetric(SimpleBaseMetric):
+    """LLM-based metric that uses prompts to generate structured responses."""
+
+    prompt: t.Optional[t.Union[str, "Prompt"]] = None
+    _response_model: t.Type["BaseModel"] = field(init=False)
+
+    def __post_init__(self):
+        if isinstance(self.prompt, str):
+            from ragas.prompt.simple_prompt import Prompt
+
+            self.prompt = Prompt(self.prompt)
+
+    def get_variables(self) -> t.List[str]:
+        if isinstance(self.prompt, (type(None), str)):
+            fstr = self.prompt
+        else:
+            fstr = self.prompt.instruction
+        if fstr is None:
+            return []
+        import string
+
+        vars = [
+            field_name
+            for _, field_name, _, _ in string.Formatter().parse(fstr)
+            if field_name
+        ]
+        return vars
+
+    def score(self, **kwargs) -> "MetricResult":
+        from ragas.metrics.result import MetricResult
+
+        llm = kwargs.pop("llm")  # Extract llm from kwargs for compatibility
+        traces = {}
+        traces["input"] = kwargs
+
+        # get prompt
+        if not self.prompt:
+            raise Exception("prompt not passed")
+        prompt_input = self.prompt.format(**kwargs)
+
+        response = llm.generate(prompt_input, response_model=self._response_model)
+        traces["output"] = response.model_dump()
+        result = MetricResult(**response.model_dump())
+        result.traces = traces
+        return result
+
+    async def ascore(self, **kwargs) -> "MetricResult":
+        from ragas.metrics.result import MetricResult
+
+        llm = kwargs.pop("llm")  # Extract llm from kwargs for compatibility
+        traces = {}
+
+        # get prompt
+        if not self.prompt:
+            raise Exception("prompt not passed")
+        prompt_input = self.prompt.format(**kwargs)
+
+        traces["input"] = prompt_input
+        response = await llm.agenerate(
+            prompt_input,
+            response_model=self._response_model,
+        )
+        traces["output"] = response.model_dump()
+        result = MetricResult(**response.model_dump())  # Fixed missing parentheses
+        result.traces = traces
+        return result
+
+    def batch_score(
+        self, inputs: t.List[t.Dict[str, t.Any]], **kwargs
+    ) -> t.List["MetricResult"]:
+        # Override base method to maintain compatibility
+        llm = kwargs.get("llm") or inputs[0].get("llm") if inputs else None
+        if llm:
+            # Add llm to each input
+            inputs_with_llm = [{**input_dict, "llm": llm} for input_dict in inputs]
+            return super().batch_score(inputs_with_llm)
+        return super().batch_score(inputs)
+
+    async def abatch_score(
+        self, inputs: t.List[t.Dict[str, t.Any]], **kwargs
+    ) -> t.List["MetricResult"]:
+        # Override base method to maintain compatibility
+        llm = kwargs.get("llm") or inputs[0].get("llm") if inputs else None
+        if llm:
+            # Add llm to each input
+            inputs_with_llm = [{**input_dict, "llm": llm} for input_dict in inputs]
+            return await super().abatch_score(inputs_with_llm)
+        return await super().abatch_score(inputs)
+
+    @abstractmethod
+    def get_correlation(
+        self, gold_labels: t.List[str], predictions: t.List[str]
+    ) -> float:
+        """
+        Calculate the correlation between gold scores and predicted scores.
+        This is a placeholder method and should be implemented based on the specific metric.
+        """
+        pass
+
+    def align_and_validate(
+        self,
+        dataset: "Dataset",
+        embedding_model: t.Union["BaseRagasEmbeddings", "BaseRagasEmbedding"],
+        llm: "BaseRagasLLM",
+        test_size: float = 0.2,
+        random_state: int = 42,
+        **kwargs: t.Dict[str, t.Any],
+    ):
+        """
+        Args:
+            dataset: experiment to align the metric with.
+            embedding_model: The embedding model used for dynamic few-shot prompting.
+            llm: The LLM instance to use for scoring.
+
+        Align the metric with the specified experiments and validate it against a gold standard experiment.
+        This method combines alignment and validation into a single step.
+        """
+        train_dataset, test_dataset = dataset.train_test_split(
+            test_size=test_size, random_state=random_state
+        )
+
+        self.align(train_dataset, embedding_model, **kwargs)  # type: ignore
+        return self.validate_alignment(llm, test_dataset)  # type: ignore
+
+    def align(
+        self,
+        train_dataset: "Dataset",
+        embedding_model: t.Union["BaseRagasEmbeddings", "BaseRagasEmbedding"],
+        **kwargs: t.Dict[str, t.Any],
+    ):
+        """
+        Args:
+            train_dataset: train_dataset to align the metric with.
+            embedding_model: The embedding model used for dynamic few-shot prompting.
+
+        Align the metric with the specified experiments by different optimization methods.
+        """
+
+        # get prompt
+        if not self.prompt:
+            raise Exception("prompt not passed")
+        from ragas.prompt.simple_prompt import Prompt
+
+        self.prompt = (
+            self.prompt if isinstance(self.prompt, Prompt) else Prompt(self.prompt)
+        )
+        # Extract specific parameters for from_prompt method
+        max_similar_examples_val = kwargs.get("max_similar_examples", 3)
+        similarity_threshold_val = kwargs.get("similarity_threshold", 0.7)
+        max_similar_examples = (
+            int(max_similar_examples_val)
+            if isinstance(max_similar_examples_val, (int, str))
+            else 3
+        )
+        similarity_threshold = (
+            float(similarity_threshold_val)
+            if isinstance(similarity_threshold_val, (int, float, str))
+            else 0.7
+        )
+        # Convert BaseRagasEmbeddings to BaseRagasEmbedding if needed
+        if hasattr(embedding_model, "embed_query"):
+            # For legacy BaseRagasEmbeddings, we need to wrap it
+            # Create a wrapper that implements BaseRagasEmbedding interface
+            class EmbeddingWrapper:
+                def __init__(self, legacy_embedding):
+                    self.legacy_embedding = legacy_embedding
+
+                def embed_text(self, text: str, **kwargs) -> t.List[float]:
+                    return self.legacy_embedding.embed_query(text)
+
+                async def aembed_text(self, text: str, **kwargs) -> t.List[float]:
+                    return await self.legacy_embedding.aembed_query(text)
+
+            actual_embedding_model = EmbeddingWrapper(embedding_model)
+        else:
+            # Already BaseRagasEmbedding
+            actual_embedding_model = embedding_model
+
+        from ragas.prompt.dynamic_few_shot import DynamicFewShotPrompt
+
+        self.prompt = DynamicFewShotPrompt.from_prompt(
+            self.prompt,
+            actual_embedding_model,  # type: ignore[arg-type]
+            max_similar_examples,
+            similarity_threshold,
+        )
+        train_dataset.reload()
+        total_items = len(train_dataset)
+        input_vars = self.get_variables()
+        output_vars = [self.name, f"{self.name}_reason"]
+
+        from rich.progress import Progress
+
+        with Progress() as progress:
+            task = progress.add_task("Processing examples", total=total_items)
+            for row in train_dataset:
+                inputs = {
+                    var: train_dataset.get_row_value(row, var) for var in input_vars
+                }
+                inputs = {k: v for k, v in inputs.items() if v is not None}
+                output = {
+                    var: train_dataset.get_row_value(row, var) for var in output_vars
+                }
+                output = {k: v for k, v in output.items() if v is not None}
+
+                if output:
+                    self.prompt.add_example(inputs, output)
+                progress.update(task, advance=1)
+
+    def validate_alignment(
+        self,
+        llm: "BaseRagasLLM",
+        test_dataset: "Dataset",
+        mapping: t.Dict[str, str] = {},
+    ):
+        """
+        Args:
+            llm: The LLM instance to use for scoring.
+            test_dataset: A Dataset instance containing the gold standard scores.
+            mapping: A dictionary mapping variable names expected by metrics to their corresponding names in the gold experiment.
+
+        Validate the alignment of the metric by comparing the scores against a gold standard experiment.
+        This method computes the Cohen's Kappa score and agreement rate between the gold standard scores and
+        the predicted scores from the metric.
+        """
+
+        test_dataset.reload()
+        gold_scores_raw = [
+            test_dataset.get_row_value(row, self.name) for row in test_dataset
+        ]
+        pred_scores = []
+        for row in test_dataset:
+            values = {
+                v: (
+                    test_dataset.get_row_value(row, v)
+                    if v not in mapping
+                    else test_dataset.get_row_value(row, mapping.get(v, v))
+                )
+                for v in self.get_variables()
+            }
+            score = self.score(llm=llm, **values)
+            pred_scores.append(score.value)
+
+        # Convert to strings for correlation calculation, filtering out None values
+        gold_scores = [str(score) for score in gold_scores_raw if score is not None]
+        pred_scores_str = [str(score) for score in pred_scores if score is not None]
+
+        df = test_dataset.to_pandas()
+        df[f"{self.name}_pred"] = pred_scores
+        correlation = self.get_correlation(gold_scores, pred_scores_str)
+        agreement_rate = sum(
+            x == y for x, y in zip(gold_scores, pred_scores_str)
+        ) / len(gold_scores)
+        return {
+            "correlation": correlation,
+            "agreement_rate": agreement_rate,
+            "df": df,
+        }
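To show how the new base class is meant to be consumed, here is a hedged sketch of a toy subclass exercising the abstract `score`/`ascore` hooks and the inherited batch helpers defined above. The subclass, its logic, and the `MetricResult(value=...)` constructor call are illustrative assumptions, not part of this commit.

```python
# Hypothetical SimpleBaseMetric subclass, for illustration only.
import asyncio
from dataclasses import dataclass

from ragas.metrics.base import SimpleBaseMetric
from ragas.metrics.result import MetricResult


@dataclass
class LengthCheck(SimpleBaseMetric):
    """Toy metric: passes when the response is non-empty."""

    def score(self, **kwargs) -> MetricResult:
        response = kwargs.get("response", "")
        return MetricResult(value="pass" if response.strip() else "fail")

    async def ascore(self, **kwargs) -> MetricResult:
        return self.score(**kwargs)


metric = LengthCheck(name="length_check")
results = metric.batch_score([{"response": "hello"}, {"response": ""}])
# abatch_score fans the same calls out through asyncio.gather (see base class above).
async_results = asyncio.run(metric.abatch_score([{"response": "hello"}]))
```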
