77
88import pandas as pd
99from datasets import Dataset
10- from langchain_core .embeddings import Embeddings
11- from langchain_core .language_models import BaseLanguageModel
1210from langchain_openai .chat_models import ChatOpenAI
1311from langchain_openai .embeddings import OpenAIEmbeddings
1412
1513from ragas ._analytics import TestsetGenerationEvent , track
16- from ragas .embeddings .base import BaseRagasEmbeddings , LangchainEmbeddingsWrapper
14+ from ragas .embeddings .base import (
15+ BaseRagasEmbeddings ,
16+ LangchainEmbeddingsWrapper ,
17+ LlamaIndexEmbeddingsWrapper ,
18+ )
1719from ragas .exceptions import ExceptionInRunner
1820from ragas .executor import Executor
19- from ragas .llms import BaseRagasLLM , LangchainLLMWrapper
21+ from ragas .llms import BaseRagasLLM , LangchainLLMWrapper , LlamaIndexLLMWrapper
2022from ragas .run_config import RunConfig
2123from ragas .testset .docstore import Document , DocumentStore , InMemoryDocumentStore
2224from ragas .testset .evolutions import (
3436
3537if t .TYPE_CHECKING :
3638 from langchain_core .documents import Document as LCDocument
39+ from langchain_core .embeddings import Embeddings as LangchainEmbeddings
40+ from langchain_core .language_models import BaseLanguageModel as LangchainLLM
41+ from llama_index .core .base .embeddings .base import (
42+ BaseEmbedding as LlamaIndexEmbeddings ,
43+ )
44+ from llama_index .core .base .llms .base import BaseLLM as LlamaindexLLM
3745 from llama_index .core .schema import Document as LlamaindexDocument
3846
3947logger = logging .getLogger (__name__ )
@@ -75,9 +83,9 @@ class TestsetGenerator:
7583 @classmethod
7684 def from_langchain (
7785 cls ,
78- generator_llm : BaseLanguageModel ,
79- critic_llm : BaseLanguageModel ,
80- embeddings : Embeddings ,
86+ generator_llm : LangchainLLM ,
87+ critic_llm : LangchainLLM ,
88+ embeddings : LangchainEmbeddings ,
8189 docstore : t .Optional [DocumentStore ] = None ,
8290 run_config : t .Optional [RunConfig ] = None ,
8391 chunk_size : int = 1024 ,
@@ -104,6 +112,36 @@ def from_langchain(
104112 docstore = docstore ,
105113 )
106114
@classmethod
def from_llama_index(
    cls,
    generator_llm: LlamaindexLLM,
    critic_llm: LlamaindexLLM,
    embeddings: LlamaIndexEmbeddings,
    docstore: t.Optional[DocumentStore] = None,
    run_config: t.Optional[RunConfig] = None,
    chunk_size: int = 1024,
) -> "TestsetGenerator":
    """Build a ``TestsetGenerator`` from LlamaIndex LLM and embedding objects.

    The LlamaIndex models are wrapped in ``LlamaIndexLLMWrapper`` /
    ``LlamaIndexEmbeddingsWrapper`` before being handed to the generator.

    Parameters
    ----------
    generator_llm:
        LlamaIndex LLM wrapped and passed on as ``generator_llm``; it also
        backs the keyphrase extractor of the default docstore.
    critic_llm:
        LlamaIndex LLM wrapped and passed on as ``critic_llm``.
    embeddings:
        LlamaIndex embedding model wrapped and passed on as ``embeddings``.
    docstore:
        Document store to use. When ``None``, an ``InMemoryDocumentStore``
        is created with a token-based splitter.
    run_config:
        Forwarded to the default ``InMemoryDocumentStore`` only; it is not
        used when a ``docstore`` is supplied.
    chunk_size:
        Chunk size given to the default ``TokenTextSplitter`` (chunk overlap
        is 0). Ignored when a ``docstore`` is supplied. Defaults to 1024,
        mirroring the ``chunk_size`` parameter of ``from_langchain``.

    Returns
    -------
    TestsetGenerator
        A generator configured with the wrapped models and the docstore.
    """
    generator_llm_model = LlamaIndexLLMWrapper(generator_llm)
    critic_llm_model = LlamaIndexLLMWrapper(critic_llm)
    embeddings_model = LlamaIndexEmbeddingsWrapper(embeddings)
    keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)

    if docstore is None:
        # Imported lazily so the splitter dependency is only needed when the
        # default docstore is actually built.
        from langchain.text_splitter import TokenTextSplitter

        splitter = TokenTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
        docstore = InMemoryDocumentStore(
            splitter=splitter,
            embeddings=embeddings_model,
            extractor=keyphrase_extractor,
            run_config=run_config,
        )

    return cls(
        generator_llm=generator_llm_model,
        critic_llm=critic_llm_model,
        embeddings=embeddings_model,
        docstore=docstore,
    )
144+
107145 @classmethod
108146 @deprecated ("0.1.4" , removal = "0.2.0" , alternative = "from_langchain" )
109147 def with_openai (
0 commit comments