Skip to content

Commit d73012f

Browse files
committed
simplify
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 25e26ce commit d73012f

File tree

2 files changed

+40
-241
lines changed

2 files changed

+40
-241
lines changed

nemo_automodel/components/datasets/llm/retrieval_dataset_inline.py

Lines changed: 19 additions & 190 deletions
Original file line numberDiff line numberDiff line change
@@ -15,144 +15,13 @@
1515
import json
1616
import logging
1717
import os
18-
from abc import ABC, abstractmethod
19-
from copy import deepcopy
20-
from dataclasses import dataclass
21-
from typing import Any, Dict, List, Optional, Union
18+
from typing import Any, Dict, List, Union
2219

23-
from datasets import Dataset, concatenate_datasets, load_dataset
20+
from datasets import Dataset, concatenate_datasets
2421

25-
EXAMPLE_TEMPLATE = {"text": "", "image": "", "nr_ocr": ""}
2622
INLINE_CORPUS_ID = "__inline__"
2723

2824

29-
class AbstractDataset(ABC):
30-
@abstractmethod
31-
def get_document_by_id(self, id):
32-
pass
33-
34-
@abstractmethod
35-
def get_all_ids(self):
36-
pass
37-
38-
39-
class TextQADataset(AbstractDataset):
40-
def __init__(self, path):
41-
self.path = path
42-
self.data = load_dataset(path)["train"]
43-
docid2idx = {}
44-
for idx, docid in enumerate(self.data["id"]):
45-
docid2idx[str(docid)] = idx
46-
self.docid2idx = docid2idx
47-
48-
def get_document_by_id(self, id):
49-
example = deepcopy(EXAMPLE_TEMPLATE)
50-
example["text"] = self.data[self.docid2idx[id]]["text"]
51-
return example
52-
53-
def get_all_ids(self):
54-
return sorted(list(self.docid2idx.keys()))
55-
56-
57-
DATASETS = {
58-
"TextQADataset": TextQADataset,
59-
}
60-
61-
62-
@dataclass
63-
class CorpusInfo:
64-
"""
65-
Data structure to hold corpus metadata and dataset object together.
66-
Provides easy access to both components with descriptive attribute names.
67-
"""
68-
69-
metadata: dict
70-
corpus: AbstractDataset
71-
72-
@property
73-
def corpus_id(self) -> str:
74-
"""Get corpus ID from metadata"""
75-
return self.metadata["corpus_id"]
76-
77-
@property
78-
def query_instruction(self) -> str:
79-
"""Get query instruction from metadata"""
80-
if "query_instruction" in self.metadata:
81-
return self.metadata["query_instruction"]
82-
else:
83-
return ""
84-
85-
@property
86-
def passage_instruction(self) -> str:
87-
"""Get passage instruction from metadata"""
88-
if "passage_instruction" in self.metadata:
89-
return self.metadata["passage_instruction"]
90-
else:
91-
return ""
92-
93-
@property
94-
def task_type(self) -> str:
95-
"""Get task type from metadata"""
96-
if "task_type" in self.metadata:
97-
return self.metadata["task_type"]
98-
else:
99-
return ""
100-
101-
@property
102-
def path(self) -> str:
103-
"""Get corpus path from the corpus object"""
104-
return self.corpus.path
105-
106-
def get_document_by_id(self, doc_id: str):
107-
"""Delegate to corpus for convenience"""
108-
return self.corpus.get_document_by_id(doc_id)
109-
110-
def get_all_ids(self):
111-
"""Delegate to corpus for convenience"""
112-
return self.corpus.get_all_ids()
113-
114-
115-
def load_corpus_metadata(path: str):
116-
path_metadata = os.path.join(path, "merlin_metadata.json")
117-
if not os.path.isfile(path_metadata):
118-
raise ValueError("Metadata File for Corpus does not exist: " + path_metadata)
119-
120-
metadata = json.load(open(path_metadata, "r"))
121-
return metadata
122-
123-
124-
def load_corpus(path, metadata: Optional[dict] = None):
125-
if metadata is None:
126-
metadata = load_corpus_metadata(path)
127-
if metadata["class"] not in DATASETS:
128-
raise ValueError("DatasetClass is not implemented: " + metadata["class"])
129-
corpus = DATASETS[metadata["class"]](path)
130-
corpus_id = metadata["corpus_id"]
131-
return (corpus_id, corpus)
132-
133-
134-
def add_corpus(qa_corpus_paths: Union[dict, list], corpus_dict: dict):
135-
if corpus_dict is None:
136-
raise ValueError("Corpus dictionary is not provided")
137-
if not isinstance(qa_corpus_paths, list):
138-
qa_corpus_paths = [qa_corpus_paths]
139-
140-
for corpus_info in qa_corpus_paths:
141-
corpus_metadata = load_corpus_metadata(corpus_info["path"])
142-
if corpus_metadata["corpus_id"] in corpus_dict:
143-
if corpus_dict[corpus_metadata["corpus_id"]].path != corpus_info["path"]:
144-
raise ValueError(
145-
"Two Different Datasets have the same corpus id but different paths: "
146-
+ "1. "
147-
+ corpus_dict[corpus_metadata["corpus_id"]].path
148-
+ "2. "
149-
+ corpus_info["path"]
150-
)
151-
else:
152-
corpus_id, corpus = load_corpus(corpus_info["path"], corpus_metadata)
153-
corpus_dict[corpus_id] = CorpusInfo(corpus_metadata, corpus)
154-
155-
15625
def _load_json_or_jsonl(path: str) -> Union[dict, list]:
15726
"""Load a JSON file, falling back to JSONL (one JSON object per line)."""
15827
with open(path, "r") as f:
@@ -183,16 +52,6 @@ def _coerce_to_list(value: Any) -> list:
18352
return [value]
18453

18554

186-
def _normalize_id_doc(doc: Any) -> Dict[str, Any]:
187-
"""Normalize a corpus-id based doc reference into a canonical dict shape."""
188-
if isinstance(doc, dict) and "id" in doc:
189-
doc_id = doc["id"]
190-
else:
191-
doc_id = doc
192-
doc_id = doc_id if isinstance(doc_id, str) else str(doc_id)
193-
return {"id": doc_id, "text": "", "image": "", "nr_ocr": ""}
194-
195-
19655
def _normalize_inline_doc(doc: Any) -> Dict[str, Any]:
19756
"""Normalize an inline doc (text/image provided) into a canonical dict shape."""
19857
if isinstance(doc, dict):
@@ -213,24 +72,19 @@ def _normalize_inline_doc(doc: Any) -> Dict[str, Any]:
21372
}
21473

21574

216-
def _resolve_doc_to_example(doc: Any, corpus_id: str, corpus_dict: Dict[str, Any]) -> dict:
75+
def _resolve_doc_to_example(doc: Any) -> dict:
21776
"""
21877
Resolve a doc reference into an example dict with keys: text, image, nr_ocr.
21978
220-
Supports:
221-
- corpus-id based docs: {"id": "..."} (looked up in corpus_dict[corpus_id])
222-
- inline docs: {"text": "...", "image": "", "nr_ocr": ""} (used directly)
79+
Supported doc forms:
80+
- `str`: interpreted as inline document text
81+
- `dict`: must include `text` (optionally `image`, `nr_ocr`)
22382
"""
83+
example = {"text": "", "image": "", "nr_ocr": ""}
22484
if isinstance(doc, dict):
225-
doc_id = doc.get("id", "")
226-
# Treat non-empty "id" as a corpus lookup.
227-
if doc_id:
228-
if corpus_id not in corpus_dict:
229-
raise KeyError(f"Corpus '{corpus_id}' not found in corpus_dict (needed to resolve doc id '{doc_id}').")
230-
return corpus_dict[corpus_id].get_document_by_id(str(doc_id))
231-
232-
# Inline doc: copy supported fields over the template.
233-
example = deepcopy(EXAMPLE_TEMPLATE)
85+
if "text" not in doc:
86+
raise ValueError(f"Inline doc dict must include 'text'. Got keys: {sorted(list(doc.keys()))}")
87+
23488
if "text" in doc and doc["text"] is not None:
23589
example["text"] = str(doc["text"])
23690
if "image" in doc and doc["image"] is not None:
@@ -239,16 +93,11 @@ def _resolve_doc_to_example(doc: Any, corpus_id: str, corpus_dict: Dict[str, Any
23993
example["nr_ocr"] = str(doc["nr_ocr"])
24094
return example
24195

242-
# String docs are interpreted as ids only when a corpus is available; otherwise as inline text.
24396
if isinstance(doc, str):
244-
if corpus_id in corpus_dict:
245-
return corpus_dict[corpus_id].get_document_by_id(doc)
246-
example = deepcopy(EXAMPLE_TEMPLATE)
24797
example["text"] = doc
24898
return example
24999

250100
# Fallback: coerce to string text
251-
example = deepcopy(EXAMPLE_TEMPLATE)
252101
example["text"] = str(doc)
253102
return example
254103

@@ -264,39 +113,19 @@ def load_datasets(data_dir_list: Union[List[str], str], concatenate: bool = True
264113
"""
265114
if not isinstance(data_dir_list, list):
266115
data_dir_list = [data_dir_list]
267-
corpus_dict = {}
268116
datasets = []
269117
for data_dir in data_dir_list:
270118
train_data = _load_json_or_jsonl(data_dir)
271119

272-
# Corpus-id based format:
273-
# {
274-
# "corpus": [{"path": "..."}],
275-
# "data": [{"question_id": "...", "question": "...", "corpus_id": "...", "pos_doc": [{"id": "..."}], ...}]
276-
# }
120+
# Corpus-id based format is intentionally not supported in this "inline" loader.
121+
# Use `nemo_automodel.components.datasets.llm.retrieval_dataset.load_datasets` instead.
277122
is_corpus_id_format = isinstance(train_data, dict) and "corpus" in train_data and "data" in train_data
278123
if is_corpus_id_format:
279-
REQUIRED_FIELDS = ["question_id", "question", "corpus_id", "pos_doc", "neg_doc"]
280-
281-
qa_corpus_paths = train_data["corpus"]
282-
add_corpus(qa_corpus_paths, corpus_dict)
283-
284-
normalized_data = []
285-
for item in train_data["data"]:
286-
missing = [f for f in REQUIRED_FIELDS if f not in item]
287-
if missing:
288-
raise ValueError(f"Missing required fields: {missing} in train_data item: {item}")
289-
normalized_item = {
290-
"question_id": item["question_id"],
291-
"question": item["question"],
292-
"corpus_id": item["corpus_id"],
293-
"pos_doc": [_normalize_id_doc(d) for d in _coerce_to_list(item["pos_doc"])],
294-
"neg_doc": [_normalize_id_doc(d) for d in _coerce_to_list(item["neg_doc"])],
295-
}
296-
normalized_data.append(normalized_item)
297-
298-
datasets.append(Dataset.from_list(normalized_data))
299-
continue
124+
raise ValueError(
125+
"Corpus-id retrieval format (top-level 'corpus' + 'data') is not supported by "
126+
"retrieval_dataset_inline. Use retrieval_dataset.py (corpus-id) or convert the dataset "
127+
"to inline JSONL with inline `pos_doc`/`neg_doc` texts."
128+
)
300129

301130
# Inline-text format (JSONL or JSON list/dict). Example record:
302131
# {"query": "...", "pos_doc": "...", "neg_doc": ["...", "..."]}
@@ -347,7 +176,7 @@ def load_datasets(data_dir_list: Union[List[str], str], concatenate: bool = True
347176
dataset = concatenate_datasets(datasets)
348177
else:
349178
dataset = datasets
350-
return (dataset, corpus_dict)
179+
return (dataset, {})
351180

352181

353182
def _transform_func(examples, num_neg_docs, corpus_dict, use_dataset_instruction: bool = False):
@@ -415,7 +244,7 @@ def _transform_func(examples, num_neg_docs, corpus_dict, use_dataset_instruction
415244
cur_corpus_id = corpus_ids[idx_doc]
416245

417246
for doc in docs:
418-
cur_doc = _resolve_doc_to_example(doc, cur_corpus_id, corpus_dict)
247+
cur_doc = _resolve_doc_to_example(doc)
419248

420249
# Extract text
421250
if cur_doc["text"] != "" and not cur_doc["image"]:

tests/unit_tests/datasets/llm/test_retrieval_dataset.py

Lines changed: 21 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -603,34 +603,24 @@ def test_inline_normalization_and_resolution_branches():
603603
with pytest.raises(ValueError, match="Inline doc dict must include 'text'"):
604604
rdi._normalize_inline_doc({"image": "x"})
605605

606-
# _resolve_doc_to_example: id lookup requires corpus present
607-
with pytest.raises(KeyError, match="Corpus 'missing' not found"):
608-
rdi._resolve_doc_to_example({"id": "123"}, corpus_id="missing", corpus_dict={})
609-
610-
corpus_dict = {
611-
"c": DummyCorpus({"x": {"text": "T", "image": "", "nr_ocr": ""}}),
612-
}
613-
614-
# String doc -> treat as id when corpus available
615-
ex = rdi._resolve_doc_to_example("x", corpus_id="c", corpus_dict=corpus_dict)
616-
assert ex["text"] == "T"
606+
# _resolve_doc_to_example: dict missing "text" should raise
607+
with pytest.raises(ValueError, match="Inline doc dict must include 'text'"):
608+
rdi._resolve_doc_to_example({"id": "123"})
617609

618-
# String doc -> treat as inline text when corpus missing
619-
ex2 = rdi._resolve_doc_to_example("hello", corpus_id=rdi.INLINE_CORPUS_ID, corpus_dict={})
620-
assert ex2["text"] == "hello"
610+
# String doc -> treated as inline text
611+
ex = rdi._resolve_doc_to_example("hello")
612+
assert ex["text"] == "hello"
621613

622614
# Inline dict doc: id empty -> use inline fields
623615
inline = rdi._resolve_doc_to_example(
624616
{"id": "", "text": "txt", "image": None, "nr_ocr": 123},
625-
corpus_id=rdi.INLINE_CORPUS_ID,
626-
corpus_dict={},
627617
)
628618
assert inline["text"] == "txt"
629619
assert inline["image"] == "" # None -> ""
630620
assert inline["nr_ocr"] == "123"
631621

632622
# Fallback: non-str doc coerces to string
633-
ex3 = rdi._resolve_doc_to_example(123, corpus_id=rdi.INLINE_CORPUS_ID, corpus_dict={})
623+
ex3 = rdi._resolve_doc_to_example(123)
634624
assert ex3["text"] == "123"
635625

636626

@@ -681,20 +671,10 @@ def test_load_datasets_inline_dict_container_and_error_cases(tmp_path):
681671
rdi.load_datasets(str(f_bad_container))
682672

683673

684-
def test_load_datasets_corpus_id_format_in_inline_module(tmp_path, monkeypatch):
685-
corpus_dir = tmp_path / "corpusA"
686-
corpus_dir.mkdir()
687-
(corpus_dir / "merlin_metadata.json").write_text(json.dumps({"class": "TextQADataset", "corpus_id": "corpusA"}))
688-
689-
# Provide minimal HF dataset for TextQADataset
690-
monkeypatch.setattr(
691-
rdi,
692-
"load_dataset",
693-
_mock_hf_load_dataset_returning([{"id": "p", "text": "P"}, {"id": "n1", "text": "N1"}]),
694-
)
695-
696-
good = {
697-
"corpus": [{"path": str(corpus_dir)}],
674+
def test_load_datasets_corpus_id_format_in_inline_module(tmp_path):
675+
"""The inline loader should reject corpus-id format (use retrieval_dataset.py instead)."""
676+
data = {
677+
"corpus": [{"path": str(tmp_path / "corpus")}],
698678
"data": [
699679
{
700680
"question_id": "q1",
@@ -705,25 +685,10 @@ def test_load_datasets_corpus_id_format_in_inline_module(tmp_path, monkeypatch):
705685
}
706686
],
707687
}
708-
f_good = tmp_path / "train.json"
709-
f_good.write_text(json.dumps(good))
710-
711-
ds, corpus_dict = rdi.load_datasets(str(f_good))
712-
assert len(ds) == 1
713-
assert "corpusA" in corpus_dict
714-
row = ds[0]
715-
assert row["question_id"] == "q1"
716-
assert row["pos_doc"][0]["id"] == "p"
717-
assert row["neg_doc"][0]["id"] == "n1"
718-
719-
bad = {
720-
"corpus": [{"path": str(corpus_dir)}],
721-
"data": [{"question": "Q1", "corpus_id": "corpusA", "pos_doc": [{"id": "p"}], "neg_doc": [{"id": "n1"}]}],
722-
}
723-
f_bad = tmp_path / "bad.json"
724-
f_bad.write_text(json.dumps(bad))
725-
with pytest.raises(ValueError, match="Missing required fields"):
726-
rdi.load_datasets(str(f_bad))
688+
f = tmp_path / "train.json"
689+
f.write_text(json.dumps(data))
690+
with pytest.raises(ValueError, match=r"Corpus-id retrieval format.*not supported.*retrieval_dataset_inline"):
691+
rdi.load_datasets(str(f))
727692

728693

729694
def test_transform_func_inline_error_and_num_neg_docs_zero():
@@ -770,7 +735,12 @@ def test_transform_func_inline_with_dataset_instruction_from_corpus():
770735
passage_instruction="PI",
771736
)
772737
}
773-
examples = {"question": ["Q"], "corpus_id": ["c"], "pos_doc": [[{"id": "p"}]], "neg_doc": [[{"id": "n"}]]}
738+
examples = {
739+
"question": ["Q"],
740+
"corpus_id": ["c"],
741+
"pos_doc": [[{"id": "", "text": "P", "image": "", "nr_ocr": ""}]],
742+
"neg_doc": [[{"id": "", "text": "N", "image": "", "nr_ocr": ""}]],
743+
}
774744
out = rdi._transform_func(examples, num_neg_docs=1, corpus_dict=corpus_dict, use_dataset_instruction=True)
775745
assert out["query_instruction"][0] == "QI"
776746
assert out["passage_instruction"][0] == "PI"

0 commit comments

Comments
 (0)