Commit f6a932a
Made embeddings and LLMs dependent on metric in evaluate function (#628)
Since we already specify whether or not a metric requires an LLM or embeddings model via inheritance from `MetricWithLLM` and `MetricWithEmbeddings`, there is no need to force default LLMs/embeddings in `evaluate` when, for example, no metrics that need embeddings appear in `metrics`. Initializing the LLM/embeddings model at the metric level should clarify how to use `evaluate`, and will keep things simpler as more metrics are added, since it decouples `evaluate` from initializing LLM/embeddings models for metrics that may not need them. Both are already optional arguments to the function itself. A minimal sketch of the resulting pattern follows this description.

**Copilot Description**

This pull request mainly refactors the `evaluate` function in `src/ragas/evaluation.py`. The changes optimize the import and usage of `llm_factory` and `embedding_factory`, and clarify the function's comments. The main changes:

* `src/ragas/evaluation.py`: two new imports were added at the top of the file: `embedding_factory` from `ragas.embeddings.base` and `llm_factory` from `ragas.llms`. This avoids repetitive imports inside the `evaluate` function.

Changes within the `evaluate` function:

* The comments for the `llm` and `embeddings` parameters now specify that the default language model and embeddings are used only for metrics which require an LLM or embeddings. This provides more clarity on the function's behavior.
* The conditional logic for setting `llm` and `embeddings` was simplified: `llm_factory` and `embedding_factory` are now called only when `llm` or `embeddings` is `None` and a metric actually requires them. This removes the need to import the factories inside the function.
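Below is a minimal, self-contained sketch of the lazy-assignment pattern this commit introduces. The class and factory names mirror the real ragas ones, but the bodies here are simplified stand-ins, not the actual implementations:

```python
import typing as t


class MetricWithLLM:
    """Marker base class: metrics inheriting this require an LLM."""

    def __init__(self) -> None:
        self.llm: t.Optional[object] = None


class MetricWithEmbeddings:
    """Marker base class: metrics inheriting this require embeddings."""

    def __init__(self) -> None:
        self.embeddings: t.Optional[object] = None


def llm_factory() -> object:
    # Stand-in for ragas.llms.llm_factory(), which builds the default LLM.
    return "default-llm"


def embedding_factory() -> object:
    # Stand-in for ragas.embeddings.base.embedding_factory().
    return "default-embeddings"


def assign_models(
    metrics: t.Sequence[object],
    llm: t.Optional[object] = None,
    embeddings: t.Optional[object] = None,
) -> None:
    """Attach models to metrics, building defaults only on first demand."""
    for metric in metrics:
        if isinstance(metric, MetricWithLLM) and metric.llm is None:
            if llm is None:
                # Only reached if some metric needs an LLM and the caller
                # did not pass one; the factory then runs at most once.
                llm = llm_factory()
            metric.llm = llm
        if isinstance(metric, MetricWithEmbeddings) and metric.embeddings is None:
            if embeddings is None:
                embeddings = embedding_factory()
            metric.embeddings = embeddings


class FaithfulnessLike(MetricWithLLM):
    """Hypothetical LLM-only metric; it never triggers embedding_factory."""


assign_models([FaithfulnessLike()])  # builds a default LLM, no embeddings
```

With an LLM-only metric, `embedding_factory` is never called, which is exactly the decoupling described above.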
1 parent: 402dc7e

File tree (1 file changed, +18 −21 lines):

src/ragas/evaluation.py

Lines changed: 18 additions & 21 deletions
```diff
@@ -10,7 +10,8 @@
 
 from ragas._analytics import EvaluationEvent, track
 from ragas.callbacks import new_group
-from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
+from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper, embedding_factory
+from ragas.llms import llm_factory
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
 from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
@@ -57,11 +58,11 @@ def evaluate(
         evaluation on the best set of metrics to give a complete view.
     llm: BaseRagasLLM, optional
         The language model to use for the metrics. If not provided then ragas will use
-        the default language model. This can we overridden by the llm specified in
+        the default language model for metrics which require an LLM. This can we overridden by the llm specified in
         the metric level with `metric.llm`.
     embeddings: BaseRagasEmbeddings, optional
         The embeddings to use for the metrics. If not provided then ragas will use
-        the default embeddings. This can we overridden by the embeddings specified in
+        the default embeddings for metrics which require embeddings. This can we overridden by the embeddings specified in
         the metric level with `metric.embeddings`.
     callbacks: Callbacks, optional
         Lifecycle Langchain Callbacks to run during evaluation. Check the
@@ -144,34 +145,30 @@ def evaluate(
     validate_column_dtypes(dataset)
 
     # set the llm and embeddings
-    if llm is None:
-        from ragas.llms import llm_factory
-
-        llm = llm_factory()
-    elif isinstance(llm, LangchainLLM):
+    if isinstance(llm, LangchainLLM):
         llm = LangchainLLMWrapper(llm, run_config=run_config)
-    if embeddings is None:
-        from ragas.embeddings.base import embedding_factory
-
-        embeddings = embedding_factory()
-    elif isinstance(embeddings, LangchainEmbeddings):
+    if isinstance(embeddings, LangchainEmbeddings):
         embeddings = LangchainEmbeddingsWrapper(embeddings)
+
     # init llms and embeddings
     binary_metrics = []
     llm_changed: t.List[int] = []
     embeddings_changed: t.List[int] = []
     answer_correctness_is_set = -1
+
     for i, metric in enumerate(metrics):
         if isinstance(metric, AspectCritique):
             binary_metrics.append(metric.name)
-        if isinstance(metric, MetricWithLLM):
-            if metric.llm is None:
-                metric.llm = llm
-                llm_changed.append(i)
-        if isinstance(metric, MetricWithEmbeddings):
-            if metric.embeddings is None:
-                metric.embeddings = embeddings
-                embeddings_changed.append(i)
+        if isinstance(metric, MetricWithLLM) and metric.llm is None:
+            if llm is None:
+                llm = llm_factory()
+            metric.llm = llm
+            llm_changed.append(i)
+        if isinstance(metric, MetricWithEmbeddings) and metric.embeddings is None:
+            if embeddings is None:
+                embeddings = embedding_factory()
+            metric.embeddings = embeddings
+            embeddings_changed.append(i)
         if isinstance(metric, AnswerCorrectness):
             if metric.answer_similarity is None:
                 answer_correctness_is_set = i
```
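Assuming ragas's usual column schema and public API at the time of this commit, the practical effect is that a call like the following no longer instantiates a default embeddings model, because `faithfulness` requires only an LLM (the dataset contents here are purely illustrative):

```python
from datasets import Dataset

from ragas import evaluate
from ragas.metrics import faithfulness  # an LLM-only metric (MetricWithLLM)

# Toy single-row dataset with the columns faithfulness evaluates on.
ds = Dataset.from_dict(
    {
        "question": ["What is the capital of France?"],
        "answer": ["Paris is the capital of France."],
        "contexts": [["Paris is the capital and largest city of France."]],
    }
)

# No `llm` or `embeddings` passed: after this commit, llm_factory() runs
# once (faithfulness needs an LLM), and embedding_factory() never runs.
result = evaluate(ds, metrics=[faithfulness])
```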
