Skip to content

Commit fcaf4d0

Browse files
add embeddings to TestsetGenerator (#1562)
Addresses the "no embeddings found" and "API Connection error" issues. Specifically issues: [1546](#1546), [1526](#1526), [1512](#1512), [1496](#1496) Users have reported that they cannot generate a Testset because they get API connection errors, or their knowledge graph does not have the embeddings. This is due to the use of the default LLMs and Embedding models via llm_factory and embedding_factory. The errors are occurring because the users do not have OpenAI credentials in their environment, since they are using different models in their workflow. The issue to solve is to prevent the default_transforms function from using the llm_factory by requiring the user to supply both an embedding model and an LLM when instantiating TestsetGenerator. 1. Added `embedding_model` as an attribute to `TestsetGenerator`. 2. Added `embedding_model: LangchainEmbeddings` as a parameter to `TestsetGenerator.from_langchain` 3. Changed the return from `TestsetGenerator.from_langchain` to `return cls(LangchainLLMWrapper(llm), LangchainEmbeddingsWrapper(embedding_model), knowledge_graph)` 4. Added both an `llm` and `embedding_model` parameter to `TestsetGenerator.generate_with_langchain_docs`
1 parent ade46fb commit fcaf4d0

File tree

1 file changed

+29
-12
lines changed

1 file changed

+29
-12
lines changed

src/ragas/testset/synthesizers/generate.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ragas.executor import Executor
1313
from ragas.llms import BaseRagasLLM, LangchainLLMWrapper
1414
from ragas.run_config import RunConfig
15+
from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
1516
from ragas.testset.graph import KnowledgeGraph, Node, NodeType
1617
from ragas.testset.synthesizers import default_query_distribution
1718
from ragas.testset.synthesizers.testset_schema import Testset, TestsetSample
@@ -22,6 +23,7 @@
2223
from langchain_core.callbacks import Callbacks
2324
from langchain_core.documents import Document as LCDocument
2425
from langchain_core.language_models import BaseLanguageModel as LangchainLLM
26+
from langchain_core.embeddings.embeddings import Embeddings as LangchainEmbeddings
2527

2628
from ragas.embeddings.base import BaseRagasEmbeddings
2729
from ragas.llms.base import BaseRagasLLM
@@ -42,24 +44,32 @@ class TestsetGenerator:
4244
----------
4345
llm : BaseRagasLLM
4446
The language model to use for the generation process.
47+
embedding_model: BaseRagasEmbeddings
48+
Embedding model for generation process.
4549
knowledge_graph : KnowledgeGraph, default empty
4650
The knowledge graph to use for the generation process.
4751
"""
4852

4953
llm: BaseRagasLLM
54+
embedding_model: BaseRagasEmbeddings
5055
knowledge_graph: KnowledgeGraph = field(default_factory=KnowledgeGraph)
5156

5257
@classmethod
5358
def from_langchain(
5459
cls,
5560
llm: LangchainLLM,
61+
embedding_model: LangchainEmbeddings,
5662
knowledge_graph: t.Optional[KnowledgeGraph] = None,
5763
) -> TestsetGenerator:
5864
"""
5965
Creates a `TestsetGenerator` from a Langchain LLMs.
6066
"""
6167
knowledge_graph = knowledge_graph or KnowledgeGraph()
62-
return cls(LangchainLLMWrapper(llm), knowledge_graph)
68+
return cls(
69+
LangchainLLMWrapper(llm),
70+
LangchainEmbeddingsWrapper(embedding_model),
71+
knowledge_graph
72+
)
6373

6474
def generate_with_langchain_docs(
6575
self,
@@ -77,19 +87,26 @@ def generate_with_langchain_docs(
7787
"""
7888
Generates an evaluation dataset based on given scenarios and parameters.
7989
"""
80-
if transforms is None:
81-
# use default transforms
82-
if transforms_llm is None:
83-
transforms_llm = self.llm
84-
logger.info("Using TestGenerator.llm for transforms")
85-
if transforms_embedding_model is None:
86-
raise ValueError(
87-
"embedding_model must be provided for default_transforms. Alternatively you can provide your own transforms through the `transforms` parameter."
90+
91+
# force the user to provide an llm and embedding client to prevent use of default LLMs
92+
if not self.llm and not transforms_llm:
93+
raise ValueError(
94+
'''An llm client was not provided.
95+
Provide an LLM on TestsetGenerator instantiation or as an argument for transforms_llm parameter.
96+
Alternatively you can provide your own transforms through the `transforms` parameter.'''
97+
)
98+
if not self.embedding_model and not transforms_embedding_model:
99+
raise ValueError(
100+
'''An embedding client was not provided.
101+
Provide an embedding model on TestsetGenerator instantiation or as an argument for transforms_llm parameter.
102+
Alternatively you can provide your own transforms through the `transforms` parameter.'''
88103
)
104+
105+
if not transforms:
89106
transforms = default_transforms(
90-
llm=transforms_llm or self.llm,
91-
embedding_model=transforms_embedding_model,
92-
)
107+
llm=transforms_llm or self.llm,
108+
embedding_model=transforms_embedding_model or self.embedding_model
109+
)
93110

94111
# convert the documents to Ragas nodes
95112
nodes = []

0 commit comments

Comments
 (0)