
Commit 643ab66

feat: dataloader for annotated json (#1723)
1 parent 9bd1402 commit 643ab66

2 files changed: +315 −0 lines changed


src/ragas/dataset_schema.py

Lines changed: 260 additions & 0 deletions
@@ -1,10 +1,13 @@
from __future__ import annotations

import json
import random
import typing as t
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field

import numpy as np
from datasets import Dataset as HFDataset
from pydantic import BaseModel, field_validator

@@ -526,3 +529,260 @@ def upload(self, base_url: str = RAGAS_API_URL, verbose: bool = True) -> str:
        if verbose:
            print(f"Evaluation results uploaded! View at {evaluation_endpoint}")
        return evaluation_endpoint


class PromptAnnotation(BaseModel):
    prompt_input: t.Dict[str, t.Any]
    prompt_output: t.Dict[str, t.Any]
    is_accepted: bool
    edited_output: t.Union[t.Dict[str, t.Any], None]

    def __getitem__(self, key):
        return getattr(self, key)


class SampleAnnotation(BaseModel):
    metric_input: t.Dict[str, t.Any]
    metric_output: float
    prompts: t.Dict[str, PromptAnnotation]
    is_accepted: bool
    target: t.Optional[float] = None

    def __getitem__(self, key):
        return getattr(self, key)


class MetricAnnotation(BaseModel):

    root: t.Dict[str, t.List[SampleAnnotation]]

    def __getitem__(self, key):
        return SingleMetricAnnotation(name=key, samples=self.root[key])

    @classmethod
    def from_json(cls, path, metric_name: t.Optional[str]) -> "MetricAnnotation":

        dataset = json.load(open(path))
        if metric_name is not None and metric_name not in dataset:
            raise ValueError(f"Split {metric_name} not found in the dataset.")

        return cls(
            root={
                key: [SampleAnnotation(**sample) for sample in value]
                for key, value in dataset.items()
                if metric_name is None or key == metric_name
            }
        )

    def __len__(self):
        return sum(len(value) for value in self.root.values())


class SingleMetricAnnotation(BaseModel):
    name: str
    samples: t.List[SampleAnnotation]

    def to_evaluation_dataset(self) -> EvaluationDataset:
        samples = [sample.metric_input for sample in self.samples]
        return EvaluationDataset.from_list(samples)

    def __getitem__(self, idx):
        return self.samples[idx]

    def __repr__(self):
        return f"SingleMetricAnnotation(name={self.name}, len={len(self.samples)})"

    def __iter__(self) -> t.Iterator[SampleAnnotation]:  # type: ignore
        return iter(self.samples)

    def select(self, indices: t.List[int]) -> "SingleMetricAnnotation":
        return SingleMetricAnnotation(
            name=self.name,
            samples=[self.samples[idx] for idx in indices],
        )

    @classmethod
    def from_json(cls, path) -> "SingleMetricAnnotation":

        dataset = json.load(open(path))

        return cls(
            name=dataset["name"],
            samples=[SampleAnnotation(**sample) for sample in dataset["samples"]],
        )

    def filter(self, function: t.Optional[t.Callable] = None):

        if function is None:
            function = lambda x: True  # noqa: E731

        return SingleMetricAnnotation(
            name=self.name,
            samples=[sample for sample in self.samples if function(sample)],
        )

    def __len__(self):
        return len(self.samples)

    def train_test_split(
        self,
        test_size: float = 0.2,
        seed: int = 42,
        stratify: t.Optional[t.List[t.Any]] = None,
    ) -> t.Tuple["SingleMetricAnnotation", "SingleMetricAnnotation"]:
        """
        Split the dataset into training and testing sets.

        Parameters:
            test_size (float): The proportion of the dataset to include in the test split.
            seed (int): Random seed for reproducibility.
            stratify (list): The column values to stratify the split on.
        """
        raise NotImplementedError

    def sample(
        self, n: int, stratify_key: t.Optional[str] = None
    ) -> "SingleMetricAnnotation":
        """
        Create a subset of the dataset.

        Parameters:
            n (int): The number of samples to include in the subset.
            stratify_key (str): The column to stratify the subset on.

        Returns:
            SingleMetricAnnotation: A subset of the dataset with `n` samples.
        """
        if n > len(self.samples):
            raise ValueError(
                "Requested sample size exceeds the number of available samples."
            )

        if stratify_key is None:
            # Simple random sampling
            sampled_indices = random.sample(range(len(self.samples)), n)
            sampled_samples = [self.samples[i] for i in sampled_indices]
        else:
            # Stratified sampling
            class_groups = defaultdict(list)
            for idx, sample in enumerate(self.samples):
                key = sample[stratify_key]
                class_groups[key].append(idx)

            # Determine the proportion of samples to take from each class
            total_samples = sum(len(indices) for indices in class_groups.values())
            proportions = {
                cls: len(indices) / total_samples
                for cls, indices in class_groups.items()
            }

            sampled_indices = []
            for cls, indices in class_groups.items():
                cls_sample_count = int(np.round(proportions[cls] * n))
                cls_sample_count = min(
                    cls_sample_count, len(indices)
                )  # Don't oversample
                sampled_indices.extend(random.sample(indices, cls_sample_count))

            # Handle any rounding discrepancies to ensure exactly `n` samples
            while len(sampled_indices) < n:
                remaining_indices = set(range(len(self.samples))) - set(sampled_indices)
                if not remaining_indices:
                    break
                sampled_indices.append(random.choice(list(remaining_indices)))

            sampled_samples = [self.samples[i] for i in sampled_indices]

        return SingleMetricAnnotation(name=self.name, samples=sampled_samples)

    def batch(
        self,
        batch_size: int,
        drop_last_batch: bool = False,
    ):
        """
        Create a batch iterator.

        Parameters:
            batch_size (int): The number of samples in each batch.
            stratify (str): The column to stratify the batches on.
            drop_last_batch (bool): Whether to drop the last batch if it is smaller than the specified batch size.
        """

        samples = self.samples[:]
        random.shuffle(samples)

        all_batches = [
            samples[i : i + batch_size]
            for i in range(0, len(samples), batch_size)
            if len(samples[i : i + batch_size]) == batch_size or not drop_last_batch
        ]

        return all_batches

    def stratified_batches(
        self,
        batch_size: int,
        stratify_key: str,
        drop_last_batch: bool = False,
        replace: bool = False,
    ) -> t.List[t.List[SampleAnnotation]]:
        """
        Create stratified batches based on a specified key, ensuring proportional representation.

        Parameters:
            batch_size (int): Number of samples per batch.
            stratify_key (str): Key in `metric_input` used for stratification (e.g., class labels).
            drop_last_batch (bool): If True, drops the last batch if it has fewer samples than `batch_size`.
            replace (bool): If True, allows reusing samples from the same class to fill a batch if necessary.

        Returns:
            List[List[SampleAnnotation]]: A list of stratified batches, each batch being a list of SampleAnnotation objects.
        """
        # Group samples based on the stratification key
        class_groups = defaultdict(list)
        for sample in self.samples:
            key = sample[stratify_key]
            class_groups[key].append(sample)

        # Shuffle each class group for randomness
        for group in class_groups.values():
            random.shuffle(group)

        # Determine the number of batches required
        total_samples = len(self.samples)
        num_batches = (
            np.ceil(total_samples / batch_size).astype(int)
            if drop_last_batch
            else np.floor(total_samples / batch_size).astype(int)
        )
        samples_per_class_per_batch = {
            cls: max(1, len(samples) // num_batches)
            for cls, samples in class_groups.items()
        }

        # Create stratified batches
        all_batches = []
        while len(all_batches) < num_batches:
            batch = []
            for cls, samples in list(class_groups.items()):
                # Determine the number of samples to take from this class
                count = min(
                    samples_per_class_per_batch[cls],
                    len(samples),
                    batch_size - len(batch),
                )
                if count > 0:
                    # Add samples from the current class
                    batch.extend(samples[:count])
                    class_groups[cls] = samples[count:]  # Remove used samples
                elif replace and len(batch) < batch_size:
                    # Reuse samples if `replace` is True
                    batch.extend(random.choices(samples, k=batch_size - len(batch)))

            # Shuffle the batch to mix classes
            random.shuffle(batch)
            if len(batch) == batch_size or not drop_last_batch:
                all_batches.append(batch)

        return all_batches

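Usage sketch (not part of this commit): one way the loader above could be driven. The file name "annotated_metric.json" and the metric key "aspect_critic" are illustrative assumptions; the JSON layout (a top-level mapping of metric name to a list of SampleAnnotation-shaped records) is inferred from MetricAnnotation.from_json.

from ragas.dataset_schema import MetricAnnotation

# Load a hypothetical export of annotated metric runs; pass metric_name=None
# to keep every metric found in the file.
annotations = MetricAnnotation.from_json(
    "annotated_metric.json",      # illustrative path
    metric_name="aspect_critic",  # illustrative metric key
)

metric_data = annotations["aspect_critic"]                  # SingleMetricAnnotation
accepted = metric_data.filter(lambda s: s["is_accepted"])   # keep reviewer-approved samples
subset = accepted.sample(10, stratify_key="metric_output")  # assumes at least 10 accepted samples
batches = subset.stratified_batches(batch_size=5, stratify_key="metric_output")
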
tests/unit/test_dataset_schema.py

Lines changed: 55 additions & 0 deletions
@@ -6,6 +6,9 @@
    EvaluationDataset,
    HumanMessage,
    MultiTurnSample,
    PromptAnnotation,
    SampleAnnotation,
    SingleMetricAnnotation,
    SingleTurnSample,
)

@@ -18,6 +21,58 @@
]


def create_sample_annotation(metric_output):
    return SampleAnnotation(
        metric_input={
            "response": "",
            "reference": "",
            "user_input": "",
        },
        metric_output=metric_output,
        prompts={
            "single_turn_aspect_critic_prompt": PromptAnnotation(
                prompt_input={
                    "response": "",
                    "reference": "",
                    "user_input": "",
                },
                prompt_output={"reason": "", "verdict": 1},
                is_accepted=True,
                edited_output=None,
            )
        },
        is_accepted=True,
        target=None,
    )


def test_loader_sample():

    annotated_samples = [create_sample_annotation(1) for _ in range(10)] + [
        create_sample_annotation(0) for _ in range(10)
    ]
    test_dataset = SingleMetricAnnotation(name="metric", samples=annotated_samples)
    sample = test_dataset.sample(2)
    assert len(sample) == 2

    sample = test_dataset.sample(2, stratify_key="metric_output")
    assert len(sample) == 2
    assert sum(item["metric_output"] for item in sample) == 1


def test_loader_batch():

    annotated_samples = [create_sample_annotation(1) for _ in range(10)] + [
        create_sample_annotation(0) for _ in range(10)
    ]
    dataset = SingleMetricAnnotation(name="metric", samples=annotated_samples)
    batches = dataset.batch(batch_size=2)
    assert all([len(item) == 2 for item in batches])

    batches = dataset.stratified_batches(batch_size=2, stratify_key="metric_output")
    assert all(sum([item["metric_output"] for item in batch]) == 1 for batch in batches)


@pytest.mark.parametrize("eval_sample", samples)
def test_evaluation_dataset(eval_sample):
    dataset = EvaluationDataset(samples=[eval_sample, eval_sample])

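Related sketch (not part of this commit): SingleMetricAnnotation.from_json reads a per-metric file with a "name" field and a "samples" list, and to_evaluation_dataset rebuilds an EvaluationDataset from the stored metric_input dicts, mirroring the fixture used in the tests above. The file name "metric_annotation.json" is an illustrative assumption.

import json

from ragas.dataset_schema import SingleMetricAnnotation

# One SampleAnnotation-shaped record; the metric_input keys line up with
# SingleTurnSample fields so the round trip to an EvaluationDataset works.
record = {
    "metric_input": {"user_input": "q", "response": "a", "reference": "a"},
    "metric_output": 1.0,
    "prompts": {},
    "is_accepted": True,
    "target": None,
}

# Hypothetical per-metric export: {"name": ..., "samples": [...]}.
with open("metric_annotation.json", "w") as f:
    json.dump({"name": "metric", "samples": [record, record]}, f)

metric = SingleMetricAnnotation.from_json("metric_annotation.json")
print(metric)                             # SingleMetricAnnotation(name=metric, len=2)
eval_ds = metric.to_evaluation_dataset()  # EvaluationDataset built from metric_input dicts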