Skip to content

Commit fc6ef22

Browse files
authored
feat: added generate_from_langchain function (#511)
1 parent d748049 commit fc6ef22

File tree

3 files changed

+40
-8
lines changed

3 files changed

+40
-8
lines changed

src/ragas/testset/docstore.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import heapq
24
import logging
35
import typing as t
@@ -12,11 +14,13 @@
1214
from langchain.text_splitter import TextSplitter
1315
from langchain_core.documents import Document as LCDocument
1416
from langchain_core.pydantic_v1 import Field
15-
from llama_index.readers.schema import Document as LlamaindexDocument
1617

1718
from ragas.embeddings.base import BaseRagasEmbeddings
1819
from ragas.executor import Executor
1920

21+
if t.TYPE_CHECKING:
22+
from llama_index.readers.schema import Document as LlamaindexDocument
23+
2024
Embedding = t.Union[t.List[float], npt.NDArray[np.float64]]
2125
logger = logging.getLogger(__name__)
2226
rng = np.random.default_rng()

src/ragas/testset/generator.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import pandas as pd
88
from langchain_openai.chat_models import ChatOpenAI
99
from langchain_openai.embeddings import OpenAIEmbeddings
10-
from llama_index.readers.schema import Document as LlamaindexDocument
1110

1211
from ragas._analytics import TesetGenerationEvent, track
1312
from ragas.embeddings import BaseRagasEmbeddings
@@ -17,9 +16,14 @@
1716
from ragas.testset.evolutions import ComplexEvolution, CurrentNodes, DataRow
1817
from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
1918

20-
logger = logging.getLogger(__name__)
19+
if t.TYPE_CHECKING:
20+
from llama_index.readers.schema import Document as LlamaindexDocument
21+
from langchain_core.documents import Document as LCDocument
22+
2123
Distributions = t.Dict[t.Any, float]
2224

25+
logger = logging.getLogger(__name__)
26+
2327

2428
@dataclass
2529
class TestDataset:
@@ -79,12 +83,14 @@ def with_openai(
7983
docstore=docstore,
8084
)
8185

86+
# if you add any arguments to this function, make sure to add them to
87+
# generate_with_langchain_docs as well
8288
def generate_with_llamaindex_docs(
8389
self,
8490
documents: t.Sequence[LlamaindexDocument],
8591
test_size: int,
8692
distributions: Distributions = {},
87-
show_debug_logs=False,
93+
with_debugging_logs=False,
8894
):
8995
# chunk documents and add to docstore
9096
self.docstore.add_documents(
@@ -94,11 +100,34 @@ def generate_with_llamaindex_docs(
94100
return self.generate(
95101
test_size=test_size,
96102
distributions=distributions,
97-
show_debug_logs=show_debug_logs,
103+
with_debugging_logs=with_debugging_logs,
104+
)
105+
106+
# if you add any arguments to this function, make sure to add them to
107+
# generate_with_llamaindex_docs as well
108+
def generate_with_langchain_docs(
109+
self,
110+
documents: t.Sequence[LCDocument],
111+
test_size: int,
112+
distributions: Distributions = {},
113+
with_debugging_logs=False,
114+
):
115+
# chunk documents and add to docstore
116+
self.docstore.add_documents(
117+
[Document.from_langchain_document(doc) for doc in documents]
118+
)
119+
120+
return self.generate(
121+
test_size=test_size,
122+
distributions=distributions,
123+
with_debugging_logs=with_debugging_logs,
98124
)
99125

100126
def generate(
101-
self, test_size: int, distributions: Distributions = {}, show_debug_logs=False
127+
self,
128+
test_size: int,
129+
distributions: Distributions = {},
130+
with_debugging_logs=False,
102131
):
103132
# init filters and evolutions
104133
for evolution in distributions:
@@ -116,7 +145,7 @@ def generate(
116145
evolution.init_evolution()
117146
if evolution.evolution_filter is None:
118147
evolution.evolution_filter = EvolutionFilter(llm=self.critic_llm)
119-
if show_debug_logs:
148+
if with_debugging_logs:
120149
from ragas.utils import patch_logger
121150

122151
patch_logger("ragas.testset.evolutions", logging.DEBUG)

tests/unit/test_analytics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ def test_load_userid_from_json_file(tmp_path, monkeypatch):
9090

9191

9292
def test_testset_generation_tracking(monkeypatch):
93-
9493
import ragas._analytics as analyticsmodule
9594
from ragas._analytics import TesetGenerationEvent, track
9695
from ragas.testset.evolutions import multi_context, reasoning, simple

0 commit comments

Comments
 (0)