|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import random |
4 | 5 | import typing as t |
5 | 6 | from abc import ABC, abstractmethod |
| 7 | +from collections import defaultdict |
6 | 8 | from dataclasses import dataclass, field |
7 | 9 |
|
| 10 | +import numpy as np |
8 | 11 | from datasets import Dataset as HFDataset |
9 | 12 | from pydantic import BaseModel, field_validator |
10 | 13 |
|
@@ -526,3 +529,260 @@ def upload(self, base_url: str = RAGAS_API_URL, verbose: bool = True) -> str: |
526 | 529 | if verbose: |
527 | 530 | print(f"Evaluation results uploaded! View at {evaluation_endpoint}") |
528 | 531 | return evaluation_endpoint |
| 532 | + |
| 533 | + |
class PromptAnnotation(BaseModel):
    """Input and output dicts of one prompt plus its acceptance flag.

    ``edited_output`` holds an alternative output dict, or ``None``.
    """

    prompt_input: t.Dict[str, t.Any]
    prompt_output: t.Dict[str, t.Any]
    is_accepted: bool
    edited_output: t.Optional[t.Dict[str, t.Any]]

    def __getitem__(self, key):
        """Return the attribute named ``key`` (dict-style access to fields)."""
        value = getattr(self, key)
        return value
| 542 | + |
| 543 | + |
class SampleAnnotation(BaseModel):
    """One annotated metric sample.

    Stores the metric's input dict, its float output, the per-prompt
    annotations keyed by name, an acceptance flag and an optional float
    ``target`` that defaults to ``None``.
    """

    metric_input: t.Dict[str, t.Any]
    metric_output: float
    prompts: t.Dict[str, PromptAnnotation]
    is_accepted: bool
    target: t.Optional[float] = None

    def __getitem__(self, key):
        """Return the attribute named ``key`` (dict-style access to fields)."""
        value = getattr(self, key)
        return value
| 553 | + |
| 554 | + |
class MetricAnnotation(BaseModel):
    """Sample annotations grouped by metric name.

    ``root`` maps each metric name to the list of its ``SampleAnnotation``
    objects.
    """

    root: t.Dict[str, t.List[SampleAnnotation]]

    def __getitem__(self, key):
        """Return the samples stored under ``key`` as a ``SingleMetricAnnotation``.

        Raises:
            KeyError: If ``key`` is not present in ``root``.
        """
        return SingleMetricAnnotation(name=key, samples=self.root[key])

    @classmethod
    def from_json(
        cls, path, metric_name: t.Optional[str] = None
    ) -> "MetricAnnotation":
        """
        Load annotations from a JSON file mapping metric names to sample lists.

        Parameters:
            path: Location of the JSON file; passed straight to ``open``.
            metric_name (str): If given, only this metric is loaded; if None,
                every metric in the file is loaded.

        Returns:
            MetricAnnotation: The loaded annotations.

        Raises:
            ValueError: If ``metric_name`` is given but missing from the file.
        """
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        with open(path, encoding="utf-8") as file:
            dataset = json.load(file)

        if metric_name is not None and metric_name not in dataset:
            raise ValueError(f"Split {metric_name} not found in the dataset.")

        return cls(
            root={
                key: [SampleAnnotation(**sample) for sample in value]
                for key, value in dataset.items()
                if metric_name is None or key == metric_name
            }
        )

    def __len__(self):
        """Return the total number of samples across all metrics."""
        return sum(len(value) for value in self.root.values())
| 579 | + |
| 580 | + |
class SingleMetricAnnotation(BaseModel):
    """Named list of ``SampleAnnotation`` objects for a single metric.

    Supports indexing, iteration and ``len`` over the samples, and offers
    helpers for selecting, filtering, sampling and batching them.
    """

    name: str
    samples: t.List[SampleAnnotation]

    def to_evaluation_dataset(self) -> EvaluationDataset:
        """Build an ``EvaluationDataset`` from the ``metric_input`` of every sample."""
        return EvaluationDataset.from_list(
            [sample.metric_input for sample in self.samples]
        )

    def __getitem__(self, idx):
        """Return ``self.samples[idx]``."""
        return self.samples[idx]

    def __repr__(self):
        return f"SingleMetricAnnotation(name={self.name}, len={len(self.samples)})"

    def __iter__(self) -> t.Iterator[SampleAnnotation]:  # type: ignore
        """Iterate over the samples."""
        return iter(self.samples)

    def select(self, indices: t.List[int]) -> "SingleMetricAnnotation":
        """
        Return a new annotation holding the samples at ``indices``, in that order.

        Parameters:
            indices (list): Positions into ``self.samples``.
        """
        return SingleMetricAnnotation(
            name=self.name,
            samples=[self.samples[idx] for idx in indices],
        )

    @classmethod
    def from_json(cls, path) -> "SingleMetricAnnotation":
        """
        Load an annotation from a JSON file with ``"name"`` and ``"samples"`` keys.

        Parameters:
            path: Location of the JSON file; passed straight to ``open``.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(path, encoding="utf-8") as file:
            dataset = json.load(file)

        return cls(
            name=dataset["name"],
            samples=[SampleAnnotation(**sample) for sample in dataset["samples"]],
        )

    def filter(self, function: t.Optional[t.Callable] = None):
        """
        Return a new annotation with the samples for which ``function`` is truthy.

        Parameters:
            function (callable): Predicate called with each sample. If None,
                every sample is kept.
        """
        if function is None:
            kept = list(self.samples)
        else:
            kept = [sample for sample in self.samples if function(sample)]

        return SingleMetricAnnotation(name=self.name, samples=kept)

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def train_test_split(
        self,
        test_size: float = 0.2,
        seed: int = 42,
        stratify: t.Optional[t.List[t.Any]] = None,
    ) -> t.Tuple["SingleMetricAnnotation", "SingleMetricAnnotation"]:
        """
        Split the dataset into training and testing sets.

        Not implemented yet: calling this method always raises.

        Parameters:
            test_size (float): The proportion of the dataset to include in the test split.
            seed (int): Random seed for reproducibility.
            stratify (list): The column values to stratify the split on.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def sample(
        self, n: int, stratify_key: t.Optional[str] = None
    ) -> "SingleMetricAnnotation":
        """
        Create a random subset of the dataset.

        Parameters:
            n (int): The number of samples to include in the subset.
            stratify_key (str): Attribute of each sample to stratify the subset
                on. If None, plain random sampling without replacement is used.

        Returns:
            SingleMetricAnnotation: A subset of the dataset with `n` samples.

        Raises:
            ValueError: If `n` exceeds the number of available samples.
        """
        total_samples = len(self.samples)
        if n > total_samples:
            raise ValueError(
                "Requested sample size exceeds the number of available samples."
            )

        if stratify_key is None:
            # Simple random sampling
            sampled_indices = random.sample(range(total_samples), n)
        else:
            # Stratified sampling: group sample positions by class value.
            class_groups = defaultdict(list)
            for idx, sample in enumerate(self.samples):
                class_groups[sample[stratify_key]].append(idx)

            # Take a share of `n` from each class proportional to its size.
            sampled_indices = []
            for indices in class_groups.values():
                cls_sample_count = min(
                    int(round(len(indices) / total_samples * n)),
                    len(indices),  # Don't oversample
                )
                sampled_indices.extend(random.sample(indices, cls_sample_count))

            # Per-class rounding can land above or below `n`; correct both
            # directions so that exactly `n` samples are returned.
            if len(sampled_indices) > n:
                sampled_indices = random.sample(sampled_indices, n)
            elif len(sampled_indices) < n:
                chosen = set(sampled_indices)
                remaining_indices = [
                    idx for idx in range(total_samples) if idx not in chosen
                ]
                sampled_indices.extend(
                    random.sample(remaining_indices, n - len(sampled_indices))
                )

        sampled_samples = [self.samples[i] for i in sampled_indices]
        return SingleMetricAnnotation(name=self.name, samples=sampled_samples)

    def batch(
        self,
        batch_size: int,
        drop_last_batch: bool = False,
    ):
        """
        Shuffle the samples and split them into batches.

        Parameters:
            batch_size (int): The number of samples in each batch.
            drop_last_batch (bool): Whether to drop the last batch if it is smaller than the specified batch size.

        Returns:
            List[List[SampleAnnotation]]: The batches, in shuffled order.

        Raises:
            ValueError: If `batch_size` is not a positive integer.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        # Shuffle a copy so the order of `self.samples` is left untouched.
        shuffled = self.samples[:]
        random.shuffle(shuffled)

        all_batches = [
            shuffled[start : start + batch_size]
            for start in range(0, len(shuffled), batch_size)
        ]
        # Only the final chunk can be short.
        if drop_last_batch and all_batches and len(all_batches[-1]) < batch_size:
            all_batches.pop()

        return all_batches

    def stratified_batches(
        self,
        batch_size: int,
        stratify_key: str,
        drop_last_batch: bool = False,
        replace: bool = False,
    ) -> t.List[t.List[SampleAnnotation]]:
        """
        Create stratified batches based on a specified key, ensuring proportional representation.

        Every sample is used exactly once (before any `replace` padding). Only
        the last batch can be smaller than `batch_size`.

        Parameters:
            batch_size (int): Number of samples per batch.
            stratify_key (str): Attribute of each sample used for stratification (e.g., `is_accepted`).
            drop_last_batch (bool): If True, drops the last batch if it has fewer samples than `batch_size`.
            replace (bool): If True, a short last batch is padded up to `batch_size` with samples
                drawn (with replacement) from the whole dataset; the padded batch is then kept
                even when `drop_last_batch` is True.

        Returns:
            List[List[SampleAnnotation]]: A list of stratified batches, each batch being a list of SampleAnnotation objects.

        Raises:
            ValueError: If `batch_size` is not a positive integer.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        # Group samples based on the stratification key
        class_groups = defaultdict(list)
        for sample in self.samples:
            class_groups[sample[stratify_key]].append(sample)

        # Give every sample a position in [0, 1) that spreads its class evenly
        # over the whole ordering. After sorting by that position, any run of
        # consecutive samples holds each class in roughly its overall
        # proportion. The random second component breaks ties between classes
        # without ever comparing the sample objects themselves.
        keyed = []
        for group in class_groups.values():
            random.shuffle(group)  # Randomise which samples land in which batch
            group_size = len(group)
            keyed.extend(
                ((rank + 0.5) / group_size, random.random(), sample)
                for rank, sample in enumerate(group)
            )
        keyed.sort(key=lambda entry: entry[:2])
        ordered = [entry[2] for entry in keyed]

        all_batches = [
            ordered[start : start + batch_size]
            for start in range(0, len(ordered), batch_size)
        ]

        # Only the final batch can be short; pad it or drop it as requested.
        if all_batches and len(all_batches[-1]) < batch_size:
            last_batch = all_batches[-1]
            if replace:
                last_batch.extend(
                    random.choices(self.samples, k=batch_size - len(last_batch))
                )
            elif drop_last_batch:
                all_batches.pop()

        # Shuffle each batch to mix classes
        for batch in all_batches:
            random.shuffle(batch)

        return all_batches
0 commit comments