Skip to content

Commit 0adddeb

Browse files
authored
enabled metadata support in ragulate (#580)
* enabled metadata support in ragulate * fmt * lint
1 parent ad84239 commit 0adddeb

File tree

5 files changed

+71
-44
lines changed

5 files changed

+71
-44
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .base_dataset import BaseDataset
1+
from .base_dataset import BaseDataset, QueryItem
22
from .crag_dataset import CragDataset
33
from .llama_dataset import LlamaDataset
44
from .utils import find_dataset, get_dataset
@@ -7,6 +7,7 @@
77
"BaseDataset",
88
"CragDataset",
99
"LlamaDataset",
10+
"QueryItem",
1011
"find_dataset",
1112
"get_dataset",
1213
]

libs/ragulate/ragstack_ragulate/datasets/base_dataset.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,37 @@
33
from abc import ABC, abstractmethod
44
from os import makedirs, path
55
from pathlib import Path
6-
from typing import Dict, List, Optional, Tuple
6+
from typing import Any, Dict, List, Optional
77

88
import aiofiles
99
import aiohttp
1010
from tqdm.asyncio import tqdm
1111

1212

13+
class QueryItem:
    """A single dataset query together with its associated metadata.

    Attributes:
        query: the natural-language question to send to a pipeline.
        metadata: arbitrary per-query key/value pairs (e.g. extra fields
            carried alongside the query by the source dataset).
    """

    query: str
    metadata: Dict[str, Any]

    def __init__(self, query: str, metadata: Dict[str, Any]):
        self.query = query
        self.metadata = metadata
20+
21+
1322
class BaseDataset(ABC):
1423

1524
root_storage_path: str
1625
name: str
1726
_subsets: List[str] = []
27+
_query_items: List[QueryItem]
28+
_golden_set: List[Dict[str, str]]
1829

1930
def __init__(
    self, dataset_name: str, root_storage_path: str = "datasets"
):
    """Initialize the dataset.

    Args:
        dataset_name: name identifying this dataset.
        root_storage_path: directory under which dataset files are stored.
    """
    self.name = dataset_name
    self.root_storage_path = root_storage_path
    # Per-instance lists (not class-level) so datasets never share state;
    # both are filled lazily by _load_query_items_and_golden_set() on
    # first access via get_query_items() / get_golden_set().
    self._query_items = []
    self._golden_set = []
2437

2538
def storage_path(self) -> str:
2639
"""returns the path where dataset files should be stored"""
@@ -55,8 +68,20 @@ def get_source_file_paths(self) -> List[str]:
5568
"""gets a list of source file paths for a dataset"""
5669

5770
@abstractmethod
58-
def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
59-
"""gets a list of queries and golden_truth answers for a dataset"""
71+
def _load_query_items_and_golden_set(self) -> None:
72+
"""loads query_items and golden_set"""
73+
74+
def get_query_items(self) -> List[QueryItem]:
    """Return the dataset's query items, lazy-loading them on first use."""
    if not self._query_items:
        self._load_query_items_and_golden_set()
    return self._query_items
79+
80+
def get_golden_set(self) -> List[Dict[str, str]]:
    """Return the set of ground-truth answers, lazy-loading on first use."""
    if not self._golden_set:
        self._load_query_items_and_golden_set()
    return self._golden_set
6085

6186
async def _download_file(
6287
self, session: aiohttp.ClientSession, url: str, temp_file_path: str

libs/ragulate/ragstack_ragulate/datasets/crag_dataset.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import asyncio
22
import json
33
from os import path
4-
from typing import Dict, List, Optional, Tuple
4+
from typing import List, Optional
55

6-
from .base_dataset import BaseDataset
6+
from .base_dataset import BaseDataset, QueryItem
77

88

99
class CragDataset(BaseDataset):
@@ -52,10 +52,8 @@ def download_dataset(self) -> None:
5252
def get_source_file_paths(self) -> List[str]:
5353
raise NotImplementedError("Crag source files are not yet supported")
5454

55-
def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
56-
"""gets a list of queries and golden_truth answers for a dataset"""
57-
queries: List[str] = []
58-
golden_set: List[Dict[str, str]] = []
55+
def _load_query_items_and_golden_set(self) -> None:
56+
"""loads query_items and golden_set"""
5957

6058
for subset in self.subsets:
6159
if subset not in self._subset_kinds:
@@ -74,10 +72,11 @@ def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
7472

7573
query = data.get("query")
7674
answer = data.get("answer")
75+
del data["query"]
76+
del data["answer"]
7777
if query is not None and answer is not None:
78-
queries.append(query)
79-
golden_set.append({"query": query, "response": answer})
80-
81-
print(f"found {len(queries)} for subsets: {self.subsets}")
82-
83-
return queries, golden_set
78+
self._query_items.append(QueryItem(
79+
query=query,
80+
metadata=data,
81+
))
82+
self._golden_set.append({"query": query, "response": answer})

libs/ragulate/ragstack_ragulate/datasets/llama_dataset.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111

1212
from ..logging_config import logger
13-
from .base_dataset import BaseDataset
13+
from .base_dataset import BaseDataset, QueryItem
1414

1515

1616
class LlamaDataset(BaseDataset):
@@ -69,17 +69,14 @@ def get_source_file_paths(self) -> List[str]:
6969
source_path = path.join(self._get_dataset_path(), "source_files")
7070
return self.list_files_at_path(path=source_path)
7171

72-
def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]:
73-
"""gets a list of queries and golden_truth answers for a dataset"""
72+
def _load_query_items_and_golden_set(self) -> None:
    """Populate query items and the golden set from rag_dataset.json."""
    json_path = path.join(self._get_dataset_path(), "rag_dataset.json")
    with open(json_path, "r") as f:
        examples = json.load(f)["examples"]
    for example in examples:
        question = example["query"]
        # Llama datasets carry no extra per-query fields, so metadata is empty.
        self._query_items.append(QueryItem(query=question, metadata={}))
        self._golden_set.append(
            {"query": question, "response": example["reference_answer"]}
        )

libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .feedbacks import Feedbacks
2020

2121
if TYPE_CHECKING:
22-
from ragstack_ragulate.datasets import BaseDataset
22+
from ragstack_ragulate.datasets import BaseDataset, QueryItem
2323

2424

2525
class QueryPipeline(BasePipeline):
@@ -30,7 +30,7 @@ class QueryPipeline(BasePipeline):
3030
_tru: Tru
3131
_name: str
3232
_progress: tqdm[Never]
33-
_queries: dict[str, list[str]]
33+
_query_items: dict[str, list[QueryItem]]
3434
_golden_sets: dict[str, list[dict[str, str]]]
3535
_total_queries: int = 0
3636
_total_feedbacks: int = 0
@@ -61,7 +61,7 @@ def __init__(
6161
llm_provider: str = "OpenAI",
6262
model_name: str | None = None,
6363
):
64-
self._queries = {}
64+
self._query_items = {}
6565
self._golden_sets = {}
6666
super().__init__(
6767
recipe_name=recipe_name,
@@ -89,14 +89,14 @@ def __init__(
8989

9090
total_existing_queries = 0
9191
for dataset in datasets:
92-
queries, golden_set = dataset.get_queries_and_golden_set()
92+
query_items = dataset.get_query_items()
9393
if self.sample_percent < 1.0:
9494
if self.random_seed is not None:
9595
random.seed(self.random_seed)
9696
sampled_indices = random.sample(
97-
range(len(queries)), int(self.sample_percent * len(queries))
97+
range(len(query_items)), int(self.sample_percent * len(query_items))
9898
)
99-
queries = [queries[i] for i in sampled_indices]
99+
query_items = [query_items[i] for i in sampled_indices]
100100

101101
# Check for existing records and filter queries
102102
existing_records, _feedbacks = self._tru.get_records_and_feedback(
@@ -105,11 +105,15 @@ def __init__(
105105
existing_queries = existing_records["input"].dropna().tolist()
106106
total_existing_queries += len(existing_queries)
107107

108-
queries = [query for query in queries if query not in existing_queries]
108+
query_items = [
109+
query_item
110+
for query_item in query_items
111+
if query_item.query not in existing_queries
112+
]
109113

110-
self._queries[dataset.name] = queries
111-
self._golden_sets[dataset.name] = golden_set
112-
self._total_queries += len(self._queries[dataset.name])
114+
self._query_items[dataset.name] = query_items
115+
self._golden_sets[dataset.name] = dataset.get_golden_set()
116+
self._total_queries += len(self._query_items[dataset.name])
113117

114118
metric_count = 4
115119
self._total_feedbacks = self._total_queries * metric_count
@@ -129,7 +133,7 @@ def start_evaluation(self) -> None:
129133

130134
def export_results(self) -> None:
131135
"""Export results."""
132-
for dataset_name in self._queries:
136+
for dataset_name in self._query_items:
133137
records, _feedback_names = self._tru.get_records_and_feedback(
134138
app_ids=[dataset_name]
135139
)
@@ -217,7 +221,7 @@ def query(self) -> None:
217221
initial=self._finished_queries,
218222
)
219223

220-
for dataset_name in self._queries:
224+
for dataset_name in self._query_items:
221225
feedback_functions = [
222226
feedbacks.answer_correctness(
223227
golden_set=self._golden_sets[dataset_name]
@@ -234,14 +238,15 @@ def query(self) -> None:
234238
feedback_mode=FeedbackMode.DEFERRED,
235239
)
236240

237-
for query in self._queries[dataset_name]:
241+
for query_item in self._query_items[dataset_name]:
238242
if self._sigint_received:
239243
break
240244
try:
241-
with recorder:
242-
pipeline.invoke(query)
245+
with recorder as recording:
246+
recording.record_metadata = query_item.metadata
247+
pipeline.invoke(query_item.query)
243248
except Exception as e: # noqa: BLE001
244-
err = f"Query: '{query}' caused exception, skipping."
249+
err = f"Query: '{query_item.query}' caused exception, skipping."
245250
logger.exception(err)
246251
# TODO: figure out why the logger isn't working after tru-lens starts. For now use print(). # noqa: E501
247252
print(f"{err} Exception {e}")

0 commit comments

Comments
 (0)