Skip to content

Commit 3d29c44

Browse files
Make testgenerator output compatible with evaluate (#302)
* Changed context from str to List[str] so that it is consistent with eval. The output of TestDataset can now be used for evaluation.
* Fixed a typo in _generate_doc_nodes_map.
* Changed the TestDataset class to reflect the changes in test set generation. Drawback: episode_done will be True in all cases, as the data is changed at the level above.
1 parent 1cfbaa1 commit 3d29c44

File tree

1 file changed

+17 -21 lines changed

1 file changed

+17
-21
lines changed

src/ragas/testset/testset_generator.py

Lines changed: 17 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@
5757
"conditional": "_condition_question",
5858
}
5959

60-
DataRow = namedtuple("DataRow", ["question", "context", "answer", "question_type"])
60+
DataRow = namedtuple("DataRow", ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"])
6161

6262

6363
@dataclass
@@ -71,21 +71,14 @@ class TestDataset:
7171
def to_pandas(self) -> pd.DataFrame:
7272
data_samples = []
7373
for data in self.test_data:
74-
is_conv = len(data.context) > 1
75-
question_type = data.question_type
76-
data = [
77-
{
78-
"question": qstn,
79-
"context": ctx,
80-
"answer": ans,
81-
"question_type": question_type,
82-
"episode_done": True,
83-
}
84-
for qstn, ctx, ans in zip(data.question, data.context, data.answer)
85-
]
86-
if is_conv:
87-
data[0].update({"episode_done": False})
88-
data_samples.extend(data)
74+
data = {
75+
"question": data.question,
76+
"ground_truth_context": data.ground_truth_context,
77+
"ground_truth": data.ground_truth,
78+
"question_type": data.question_type,
79+
"episode_done": data.episode_done,
80+
}
81+
data_samples.append(data)
8982

9083
return pd.DataFrame.from_records(data_samples)
9184

@@ -260,10 +253,10 @@ def _remove_nodes(
260253
return available_indices
261254

262255
def _generate_doc_nodes_map(
263-
self, documenet_nodes: t.List[BaseNode]
256+
self, document_nodes: t.List[BaseNode]
264257
) -> t.Dict[str, t.List[BaseNode]]:
265258
doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list)
266-
for node in documenet_nodes:
259+
for node in document_nodes:
267260
if node.ref_doc_id:
268261
doc_nodes_map[node.ref_doc_id].append(node)
269262

@@ -398,10 +391,13 @@ def generate(
398391
is_valid_question = self._filter_question(question)
399392
if is_valid_question:
400393
context = self._generate_context(question, text_chunk)
394+
is_conv = len(context) > 1
401395
answer = self._generate_answer(question, context)
402-
samples.append(
403-
DataRow(question.split("\n"), context, answer, evolve_type)
404-
)
396+
for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), context, answer)):
397+
episode_done = False if is_conv and i==0 else True
398+
samples.append(
399+
DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
400+
)
405401
count += 1
406402
pbar.update(count)
407403

0 commit comments

Comments
 (0)