|
10 | 10 |
|
11 | 11 | from ragas._analytics import EvaluationEvent, track |
12 | 12 | from ragas.callbacks import new_group |
13 | | -from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper |
| 13 | +from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper, embedding_factory |
| 14 | +from ragas.llms import llm_factory |
14 | 15 | from ragas.exceptions import ExceptionInRunner |
15 | 16 | from ragas.executor import Executor |
16 | 17 | from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper |
@@ -57,11 +58,11 @@ def evaluate( |
57 | 58 | evaluation on the best set of metrics to give a complete view. |
58 | 59 | llm: BaseRagasLLM, optional |
59 | 60 | The language model to use for the metrics. If not provided then ragas will use |
60 | | - the default language model. This can we overridden by the llm specified in |
| 61 | + the default language model for metrics which require an LLM. This can we overridden by the llm specified in |
61 | 62 | the metric level with `metric.llm`. |
62 | 63 | embeddings: BaseRagasEmbeddings, optional |
63 | 64 | The embeddings to use for the metrics. If not provided then ragas will use |
64 | | - the default embeddings. This can we overridden by the embeddings specified in |
| 65 | + the default embeddings for metrics which require embeddings. This can we overridden by the embeddings specified in |
65 | 66 | the metric level with `metric.embeddings`. |
66 | 67 | callbacks: Callbacks, optional |
67 | 68 | Lifecycle Langchain Callbacks to run during evaluation. Check the |
@@ -144,34 +145,30 @@ def evaluate( |
144 | 145 | validate_column_dtypes(dataset) |
145 | 146 |
|
146 | 147 | # set the llm and embeddings |
147 | | - if llm is None: |
148 | | - from ragas.llms import llm_factory |
149 | | - |
150 | | - llm = llm_factory() |
151 | | - elif isinstance(llm, LangchainLLM): |
| 148 | + if isinstance(llm, LangchainLLM): |
152 | 149 | llm = LangchainLLMWrapper(llm, run_config=run_config) |
153 | | - if embeddings is None: |
154 | | - from ragas.embeddings.base import embedding_factory |
155 | | - |
156 | | - embeddings = embedding_factory() |
157 | | - elif isinstance(embeddings, LangchainEmbeddings): |
| 150 | + if isinstance(embeddings, LangchainEmbeddings): |
158 | 151 | embeddings = LangchainEmbeddingsWrapper(embeddings) |
| 152 | + |
159 | 153 | # init llms and embeddings |
160 | 154 | binary_metrics = [] |
161 | 155 | llm_changed: t.List[int] = [] |
162 | 156 | embeddings_changed: t.List[int] = [] |
163 | 157 | answer_correctness_is_set = -1 |
| 158 | + |
164 | 159 | for i, metric in enumerate(metrics): |
165 | 160 | if isinstance(metric, AspectCritique): |
166 | 161 | binary_metrics.append(metric.name) |
167 | | - if isinstance(metric, MetricWithLLM): |
168 | | - if metric.llm is None: |
169 | | - metric.llm = llm |
170 | | - llm_changed.append(i) |
171 | | - if isinstance(metric, MetricWithEmbeddings): |
172 | | - if metric.embeddings is None: |
173 | | - metric.embeddings = embeddings |
174 | | - embeddings_changed.append(i) |
| 162 | + if isinstance(metric, MetricWithLLM) and metric.llm is None: |
| 163 | + if llm is None: |
| 164 | + llm = llm_factory() |
| 165 | + metric.llm = llm |
| 166 | + llm_changed.append(i) |
| 167 | + if isinstance(metric, MetricWithEmbeddings) and metric.embeddings is None: |
| 168 | + if embeddings is None: |
| 169 | + embeddings = embedding_factory() |
| 170 | + metric.embeddings = embeddings |
| 171 | + embeddings_changed.append(i) |
175 | 172 | if isinstance(metric, AnswerCorrectness): |
176 | 173 | if metric.answer_similarity is None: |
177 | 174 | answer_correctness_is_set = i |
|
0 commit comments