 from ragas._analytics import TestsetGenerationEvent, track
 from ragas.callbacks import new_group
 from ragas.cost import TokenUsageParser
-from ragas.embeddings.base import BaseRagasEmbeddings, LlamaIndexEmbeddingsWrapper
+from ragas.embeddings.base import (
+    BaseRagasEmbeddings,
+    LangchainEmbeddingsWrapper,
+    LlamaIndexEmbeddingsWrapper,
+)
 from ragas.executor import Executor
 from ragas.llms import BaseRagasLLM, LangchainLLMWrapper, LlamaIndexLLMWrapper
 from ragas.run_config import RunConfig
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks
     from langchain_core.documents import Document as LCDocument
+    from langchain_core.embeddings import Embeddings as LangchainEmbeddings
     from langchain_core.language_models import BaseLanguageModel as LangchainLLM
     from llama_index.core.base.embeddings.base import (
         BaseEmbedding as LlamaIndexEmbedding,
@@ -55,13 +60,15 @@ class TestsetGenerator:
     """

     llm: BaseRagasLLM
+    embedding_model: BaseRagasEmbeddings
     knowledge_graph: KnowledgeGraph = field(default_factory=KnowledgeGraph)
     persona_list: t.Optional[t.List[Persona]] = None

     @classmethod
     def from_langchain(
         cls,
         llm: LangchainLLM,
+        embedding_model: LangchainEmbeddings,
         knowledge_graph: t.Optional[KnowledgeGraph] = None,
     ) -> TestsetGenerator:
         """
@@ -70,13 +77,15 @@ def from_langchain(
         knowledge_graph = knowledge_graph or KnowledgeGraph()
         return cls(
             LangchainLLMWrapper(llm),
+            LangchainEmbeddingsWrapper(embedding_model),
             knowledge_graph,
         )

     @classmethod
     def from_llama_index(
         cls,
         llm: LlamaIndexLLM,
+        embedding_model: LlamaIndexEmbedding,
         knowledge_graph: t.Optional[KnowledgeGraph] = None,
     ) -> TestsetGenerator:
         """
@@ -85,6 +94,7 @@ def from_llama_index(
         knowledge_graph = knowledge_graph or KnowledgeGraph()
         return cls(
             LlamaIndexLLMWrapper(llm),
+            LlamaIndexEmbeddingsWrapper(embedding_model),
             knowledge_graph,
         )

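With this change the generator itself carries an embedding model: both factory methods now take one alongside the LLM and wrap it in the matching Ragas adapter. A minimal usage sketch, assuming langchain_openai's ChatOpenAI and OpenAIEmbeddings as the concrete clients and the ragas.testset import path:

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from ragas.testset import TestsetGenerator

# both clients are wrapped (LangchainLLMWrapper / LangchainEmbeddingsWrapper)
# and stored on the generator instance
generator = TestsetGenerator.from_langchain(
    llm=ChatOpenAI(model="gpt-4o-mini"),
    embedding_model=OpenAIEmbeddings(),
)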
@@ -145,7 +155,7 @@ def generate_with_langchain_docs(
                 Provide an LLM on TestsetGenerator instantiation or as an argument for transforms_llm parameter.
                 Alternatively you can provide your own transforms through the `transforms` parameter."""
             )
-        if not transforms_embedding_model:
+        if not self.embedding_model and not transforms_embedding_model:
             raise ValueError(
                 """An embedding client was not provided. Provide an embedding through the transforms_embedding_model parameter. Alternatively you can provide your own transforms through the `transforms` parameter."""
             )
@@ -154,7 +164,7 @@ def generate_with_langchain_docs(
             transforms = default_transforms(
                 documents=list(documents),
                 llm=transforms_llm or self.llm,
-                embedding_model=transforms_embedding_model,
+                embedding_model=transforms_embedding_model or self.embedding_model,
             )

         # convert the documents to Ragas nodes
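Because the transform builder now falls back to the generator's own embedding model (`transforms_embedding_model or self.embedding_model`), callers that configured the generator up front no longer need to repeat the embedding client when generating from LangChain documents. A hedged sketch, assuming `docs` is a list of LangChain Document objects and `generator` was built as above (keyword names other than documents and testset_size may differ across ragas versions):

# no transforms_embedding_model passed: the generator's stored embedding model is used
dataset = generator.generate_with_langchain_docs(
    documents=docs,
    testset_size=10,
)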
@@ -208,19 +218,25 @@ def generate_with_llamaindex_docs(
             raise ValueError(
                 "An llm client was not provided. Provide an LLM on TestsetGenerator instantiation or as an argument for transforms_llm parameter. Alternatively you can provide your own transforms through the `transforms` parameter."
             )
-        if not transforms_embedding_model:
+        if not self.embedding_model and not transforms_embedding_model:
             raise ValueError(
                 "An embedding client was not provided. Provide an embedding through the transforms_embedding_model parameter. Alternatively you can provide your own transforms through the `transforms` parameter."
             )

         if not transforms:
+            # use TestsetGenerator's LLM and embedding model if no transforms_llm or transforms_embedding_model is provided
             if transforms_llm is None:
                 llm_for_transforms = self.llm
             else:
                 llm_for_transforms = LlamaIndexLLMWrapper(transforms_llm)
-            embedding_model_for_transforms = LlamaIndexEmbeddingsWrapper(
-                transforms_embedding_model
-            )
+            if transforms_embedding_model is None:
+                embedding_model_for_transforms = self.embedding_model
+            else:
+                embedding_model_for_transforms = LlamaIndexEmbeddingsWrapper(
+                    transforms_embedding_model
+                )
+
+            # create the transforms
             transforms = default_transforms(
                 documents=[LCDocument(page_content=doc.text) for doc in documents],
                 llm=llm_for_transforms,
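The LlamaIndex path now mirrors this fallback: `self.embedding_model` is used unless a LlamaIndex embedding is passed explicitly, in which case it is wrapped with LlamaIndexEmbeddingsWrapper on the spot. A sketch of driving it with LlamaIndex documents, assuming SimpleDirectoryReader as the loader and a generator created via from_llama_index (argument names other than documents and testset_size are assumptions):

from llama_index.core import SimpleDirectoryReader

# transforms_embedding_model is omitted, so the generator's stored embedding model is used
li_docs = SimpleDirectoryReader("data/").load_data()
dataset = generator.generate_with_llamaindex_docs(
    documents=li_docs,
    testset_size=10,
)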
@@ -371,7 +387,7 @@ def generate(

         # generate scenarios
         exec = Executor(
-            "Generating Scenarios",
+            desc="Generating Scenarios",
             raise_exceptions=raise_exceptions,
             run_config=run_config,
             keep_progress_bar=False,