@@ -8,6 +8,9 @@
 from langchain_core.embeddings import Embeddings as LangchainEmbeddings
 from langchain_core.language_models import BaseLanguageModel as LangchainLLM
 
+from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM
+from llama_index.core.base.embeddings.base import BaseEmbedding as LlamaIndexEmbedding
+
 from ragas._analytics import EvaluationEvent, track, track_was_completed
 from ragas.callbacks import ChainType, RagasTracer, new_group
 from ragas.dataset_schema import (
@@ -19,13 +22,14 @@
 from ragas.embeddings.base import (
     BaseRagasEmbeddings,
     LangchainEmbeddingsWrapper,
+    LlamaIndexEmbeddingsWrapper,
     embedding_factory,
 )
 from ragas.exceptions import ExceptionInRunner
 from ragas.executor import Executor
 from ragas.integrations.helicone import helicone_config
 from ragas.llms import llm_factory
-from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
 from ragas.metrics import AspectCritic
 from ragas.metrics._answer_correctness import AnswerCorrectness
 from ragas.metrics.base import (
@@ -56,8 +60,8 @@
 def evaluate(
     dataset: t.Union[Dataset, EvaluationDataset],
     metrics: t.Optional[t.Sequence[Metric]] = None,
-    llm: t.Optional[BaseRagasLLM | LangchainLLM] = None,
-    embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
+    llm: t.Optional[BaseRagasLLM | LangchainLLM | LlamaIndexLLM] = None,
+    embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings | LlamaIndexEmbedding] = None,
     callbacks: Callbacks = None,
     in_ci: bool = False,
     run_config: RunConfig = RunConfig(),
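With the widened signature above, a LlamaIndex LLM and embedding model can be passed to evaluate() directly. A minimal sketch of that call path, assuming the OpenAI integrations for LlamaIndex; the model names, column layout, and sample data are illustrative placeholders, not part of this change, and may differ across ragas/LlamaIndex versions.

from datasets import Dataset
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

from ragas import evaluate
from ragas.metrics import answer_correctness

# Placeholder evaluation data in the classic ragas column layout (illustrative only).
data = {
    "question": ["When did Apollo 11 land on the Moon?"],
    "answer": ["Apollo 11 landed on the Moon in 1969."],
    "contexts": [["Apollo 11 landed on the Moon on July 20, 1969."]],
    "ground_truth": ["Apollo 11 landed on the Moon on July 20, 1969."],
}

result = evaluate(
    dataset=Dataset.from_dict(data),
    metrics=[answer_correctness],
    llm=OpenAI(model="gpt-4o-mini"),  # LlamaIndex LLM, wrapped internally by evaluate()
    embeddings=OpenAIEmbedding(),     # LlamaIndex embedding, wrapped internally by evaluate()
)
print(result)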
@@ -182,8 +186,12 @@ def evaluate(
     # set the llm and embeddings
     if isinstance(llm, LangchainLLM):
         llm = LangchainLLMWrapper(llm, run_config=run_config)
+    elif isinstance(llm, LlamaIndexLLM):
+        llm = LlamaIndexLLMWrapper(llm, run_config=run_config)
     if isinstance(embeddings, LangchainEmbeddings):
         embeddings = LangchainEmbeddingsWrapper(embeddings)
+    elif isinstance(embeddings, LlamaIndexEmbedding):
+        embeddings = LlamaIndexEmbeddingsWrapper(embeddings)
 
     # init llms and embeddings
     binary_metrics = []
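The new elif branches mirror the existing Langchain handling: LlamaIndex objects are wrapped before the metrics are initialized. A rough sketch of the equivalent explicit wrapping, should a caller prefer to do it up front, follows; the wrapper import paths are taken from the diff, while the OpenAI model classes are the same placeholders as above.

from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

from ragas.llms.base import LlamaIndexLLMWrapper
from ragas.embeddings.base import LlamaIndexEmbeddingsWrapper
from ragas.run_config import RunConfig

run_config = RunConfig()
# What evaluate() now does internally when handed LlamaIndex inputs.
wrapped_llm = LlamaIndexLLMWrapper(OpenAI(model="gpt-4o-mini"), run_config=run_config)
wrapped_embeddings = LlamaIndexEmbeddingsWrapper(OpenAIEmbedding())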