Commit 080ad70

Add validation for required columns
Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
1 parent cff6036 commit 080ad70

2 files changed: +113 additions, -127 deletions

nemo_automodel/components/datasets/llm/retrieval_dataset.py

Lines changed: 90 additions & 101 deletions
@@ -296,23 +296,48 @@ def _load_hf_subset(repo_id: str, subset: str):
 
     corpus_id = metadata["corpus_id"]
 
+    if metadata.get("ids_only", False):
+        raise ValueError(
+            f"HF subset '{repo_id}/{subset}' has ids_only=true in its metadata, meaning "
+            f"document and query text must be resolved from an external source before use. "
+            f"This is not supported for direct HF loading. Either use a subset that contains "
+            f"inline text, or pre-process the dataset with data_preparation.py and load the "
+            f"resulting local JSON files via a file path instead."
+        )
+
     # 2. Load corpus
+    _CORPUS_REQUIRED_COLS = {"id", "text"}
     corpus_hf = load_dataset(repo_id, f"{subset}_corpus", split="train")
 
+    missing_cols = _CORPUS_REQUIRED_COLS - set(corpus_hf.column_names)
+    if missing_cols:
+        raise ValueError(
+            f"HF corpus dataset '{repo_id}/{subset}_corpus' does not match the expected schema. "
+            f"Required columns: {sorted(_CORPUS_REQUIRED_COLS)}, "
+            f"found columns: {sorted(corpus_hf.column_names)}. "
+            f"Missing: {sorted(missing_cols)}."
+        )
+
     # 3. Build HFCorpusDataset + CorpusInfo
     hf_corpus = HFCorpusDataset(corpus_hf, path=f"hf://{repo_id}/{subset}")
     corpus_info = CorpusInfo(metadata, hf_corpus)
 
     # 4. Load queries
+    _QUERY_REQUIRED_COLS = {"question", "pos_doc"}
     queries_hf = load_dataset(repo_id, subset, split="train")
 
+    missing_query_cols = _QUERY_REQUIRED_COLS - set(queries_hf.column_names)
+    if missing_query_cols:
+        raise ValueError(
+            f"HF query dataset '{repo_id}/{subset}' does not match the expected schema. "
+            f"Required columns: {sorted(_QUERY_REQUIRED_COLS)}, "
+            f"found columns: {sorted(queries_hf.column_names)}. "
+            f"Missing: {sorted(missing_query_cols)}."
+        )
+
     # 5. Normalize to the standard {question_id, question, corpus_id, pos_doc, neg_doc} shape
-    _HF_REQUIRED_FIELDS = ["question", "pos_doc"]
     normalized_data = []
     for idx, item in enumerate(queries_hf):
-        missing = [f for f in _HF_REQUIRED_FIELDS if f not in item]
-        if missing:
-            raise ValueError(f"HF subset {repo_id}/{subset} record {idx} missing required fields: {missing}")
         normalized_item = {
            "question_id": str(item.get("question_id", f"{subset}:{idx}")),
            "question": item["question"],
@@ -345,81 +370,30 @@ def _load_hf_subset(repo_id: str, subset: str):
     return normalized_data, corpus_info
 
 
-# ---------------------------------------------------------------------------
-# Public helpers for data_sources
-# ---------------------------------------------------------------------------
-
-
-def _normalize_data_source(entry):
-    """Normalize a data source entry to a dict with at minimum a ``source`` key."""
-    if isinstance(entry, str):
-        return {"source": entry}
-    elif isinstance(entry, dict):
-        if "source" not in entry:
-            raise ValueError(f"Data source dict must include a 'source' key, got keys: {list(entry.keys())}")
-        return entry
-    else:
-        raise TypeError(f"Data source entry must be a str or dict, got {type(entry).__name__}")
-
-
-def load_datasets_from_sources(data_sources: list, concatenate: bool = True):
-    """Load datasets from a list of data sources (``hf://`` URIs and/or local JSON paths).
-
-    Returns the same ``(dataset, corpus_dict)`` tuple as :func:`load_datasets`.
-    """
+def _load_hf_sources(hf_uris: List[str]):
+    """Load one or more ``hf://`` URIs and return ``(Dataset, corpus_dict)``."""
     hf_data: List[dict] = []
-    hf_corpus_dict: dict = {}
-    local_paths: List[str] = []
-
-    for entry in data_sources:
-        entry = _normalize_data_source(entry)
-        source = entry["source"]
-
-        if source.startswith(_HF_PREFIX):
-            repo_id, subset = _parse_hf_uri(source)
-            subsets = [subset] if subset is not None else _list_hf_subsets(repo_id)
-
-            for sub in subsets:
-                logging.info(f"Loading HF subset: {repo_id}/{sub}")
-                data_list, corpus_info = _load_hf_subset(repo_id, sub)
-                hf_data.extend(data_list)
-                if corpus_info.corpus_id in hf_corpus_dict:
-                    existing = hf_corpus_dict[corpus_info.corpus_id]
-                    if existing.path != corpus_info.path:
-                        raise ValueError(
-                            f"Duplicate corpus_id '{corpus_info.corpus_id}' with different paths: "
-                            f"{existing.path} vs {corpus_info.path}"
-                        )
-                else:
-                    hf_corpus_dict[corpus_info.corpus_id] = corpus_info
-        else:
-            local_paths.append(source)
-
-    datasets_list = []
-    corpus_dict = dict(hf_corpus_dict)
-
-    if hf_data:
-        datasets_list.append(Dataset.from_list(hf_data))
-
-    if local_paths:
-        local_dataset, local_corpus = load_datasets(local_paths, concatenate=True)
-        datasets_list.append(local_dataset)
-        for cid, cinfo in local_corpus.items():
-            if cid in corpus_dict and corpus_dict[cid].path != cinfo.path:
-                raise ValueError(
-                    f"Duplicate corpus_id '{cid}' with different paths: {corpus_dict[cid].path} vs {cinfo.path}"
-                )
-            corpus_dict[cid] = cinfo
-
-    if not datasets_list:
-        raise ValueError("No datasets loaded from data_sources")
-
-    if concatenate:
-        dataset = concatenate_datasets(datasets_list) if len(datasets_list) > 1 else datasets_list[0]
-    else:
-        dataset = datasets_list
+    corpus_dict: dict = {}
+
+    for uri in hf_uris:
+        repo_id, subset = _parse_hf_uri(uri)
+        subsets = [subset] if subset is not None else _list_hf_subsets(repo_id)
+
+        for sub in subsets:
+            logging.info(f"Loading HF subset: {repo_id}/{sub}")
+            data_list, corpus_info = _load_hf_subset(repo_id, sub)
+            hf_data.extend(data_list)
+            if corpus_info.corpus_id in corpus_dict:
+                existing = corpus_dict[corpus_info.corpus_id]
+                if existing.path != corpus_info.path:
+                    raise ValueError(
+                        f"Duplicate corpus_id '{corpus_info.corpus_id}' with different paths: "
+                        f"{existing.path} vs {corpus_info.path}"
+                    )
+            else:
+                corpus_dict[corpus_info.corpus_id] = corpus_info
 
-    return dataset, corpus_dict
+    return Dataset.from_list(hf_data), corpus_dict
 
 
 def _transform_func(examples, num_neg_docs, corpus_dict, use_dataset_instruction: bool = False):
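
Note: `_load_hf_sources` keeps the URI-expansion behavior of the removed helper. A sketch of how the two URI shapes resolve (return values inferred from the diff, not verified against `_parse_hf_uri`):

repo_id, subset = _parse_hf_uri("hf://org/repo/SubA")
# -> ("org/repo", "SubA"): exactly one subset is loaded.

repo_id, subset = _parse_hf_uri("hf://org/repo")
# -> ("org/repo", None): subset is None, so _list_hf_subsets(repo_id)
#    is consulted and every subset in the repository is loaded.
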
@@ -543,7 +517,6 @@ def transform(examples):
 
 def make_retrieval_dataset(
     data_dir_list: Union[List[str], str] = None,
-    data_sources: Optional[List[Union[str, dict]]] = None,
     data_type: str = "train",
     train_n_passages: int = 5,
     eval_negative_size: int = 10,
@@ -556,14 +529,13 @@ def make_retrieval_dataset(
     """
     Load and return dataset in retrieval format for biencoder training.
 
-    This function loads data from JSON files or HuggingFace ``hf://`` URIs and
-    returns it ready for training. Uses ``set_transform()`` for lazy evaluation —
-    tokenization is handled by the collator.
+    Entries in *data_dir_list* can be local JSON file paths **or** ``hf://`` URIs
+    pointing to a HuggingFace dataset repository (e.g.
+    ``hf://nvidia/embed-nemotron-dataset-v1/SciFact``). Uses ``set_transform()``
+    for lazy evaluation — tokenization is handled by the collator.
 
     Args:
-        data_dir_list: Path(s) to JSON file(s) containing training data (legacy).
-        data_sources: List of ``hf://`` URIs or local paths (preferred).
-            Exactly one of *data_dir_list* or *data_sources* must be provided.
+        data_dir_list: Path(s) to JSON file(s) or ``hf://`` URIs.
         data_type: Type of data ("train" or "eval")
         train_n_passages: Number of passages for training (1 positive + n-1 negatives)
         eval_negative_size: Number of negative documents for evaluation
@@ -584,18 +556,35 @@ def make_retrieval_dataset(
     which is more efficient for batch padding and supports dynamic processing.
     """
 
-    if data_sources is not None and data_dir_list is not None:
-        raise ValueError("data_dir_list and data_sources are mutually exclusive; provide one, not both")
-    if data_sources is not None:
-        logging.info(f"Loading data from {len(data_sources)} source(s)")
-        dataset, corpus_dict = load_datasets_from_sources(data_sources)
-    elif data_dir_list is not None:
-        logging.info(
-            f"Loading data from {data_dir_list if isinstance(data_dir_list, str) else len(data_dir_list)} file(s)"
-        )
-        dataset, corpus_dict = load_datasets(data_dir_list, concatenate=True)
-    else:
-        raise ValueError("Either data_dir_list or data_sources must be provided")
+    if data_dir_list is None:
+        raise ValueError("data_dir_list is required")
+    if not isinstance(data_dir_list, list):
+        data_dir_list = [data_dir_list]
+
+    hf_uris = [p for p in data_dir_list if p.startswith(_HF_PREFIX)]
+    local_paths = [p for p in data_dir_list if not p.startswith(_HF_PREFIX)]
+
+    logging.info(f"Loading data from {len(data_dir_list)} source(s) ({len(hf_uris)} HF, {len(local_paths)} local)")
+
+    datasets_list = []
+    corpus_dict: dict = {}
+
+    if hf_uris:
+        hf_dataset, hf_corpus = _load_hf_sources(hf_uris)
+        datasets_list.append(hf_dataset)
+        corpus_dict.update(hf_corpus)
+
+    if local_paths:
+        local_dataset, local_corpus = load_datasets(local_paths, concatenate=True)
+        datasets_list.append(local_dataset)
+        for cid, cinfo in local_corpus.items():
+            if cid in corpus_dict and corpus_dict[cid].path != cinfo.path:
+                raise ValueError(
+                    f"Duplicate corpus_id '{cid}' with different paths: {corpus_dict[cid].path} vs {cinfo.path}"
+                )
+            corpus_dict[cid] = cinfo
+
+    dataset = concatenate_datasets(datasets_list) if len(datasets_list) > 1 else datasets_list[0]
 
     logging.info(f"Loaded dataset with {len(dataset)} examples")
 
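
After this change a single `data_dir_list` can mix local files and HF URIs; the split on `_HF_PREFIX` happens internally. A usage sketch (the local file name is made up; the HF URI is the docstring example):

dataset = make_retrieval_dataset(
    data_dir_list=[
        "train_data.json",  # hypothetical local JSON file
        "hf://nvidia/embed-nemotron-dataset-v1/SciFact",
    ],
    data_type="train",
    train_n_passages=5,
)
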
@@ -629,10 +618,11 @@ def make_retrieval_dataset(
 
     parser = argparse.ArgumentParser(description="Load and transform dataset to retrieval format")
     parser.add_argument(
-        "--data_dir_list", type=str, nargs="+", default=None, help="Path(s) to JSON file(s) containing training data"
-    )
-    parser.add_argument(
-        "--data_sources", type=str, nargs="+", default=None, help="Data source(s): hf:// URIs or local JSON paths"
+        "--data_dir_list",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Path(s) to JSON file(s) or hf:// URIs",
     )
     parser.add_argument(
         "--data_type", type=str, default="train", choices=["train", "eval"], help="Type of data (train or eval)"
@@ -653,7 +643,6 @@ def make_retrieval_dataset(
 
     dataset = make_retrieval_dataset(
         data_dir_list=args.data_dir_list,
-        data_sources=args.data_sources,
         data_type=args.data_type,
         train_n_passages=args.train_n_passages,
         eval_negative_size=args.eval_negative_size,
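
With `--data_dir_list` now required, a command line mirroring the usage sketch above might look like this (module entry point assumed from the argparse block; paths illustrative):

python -m nemo_automodel.components.datasets.llm.retrieval_dataset \
    --data_dir_list train_data.json hf://nvidia/embed-nemotron-dataset-v1/SciFact \
    --data_type train \
    --train_n_passages 5
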

tests/unit_tests/datasets/llm/test_retrieval_dataset.py

Lines changed: 23 additions & 26 deletions
@@ -838,17 +838,6 @@ def test_parse_hf_uri():
     rd._parse_hf_uri("hf://nvidia")
 
 
-def test_normalize_data_source():
-    assert rd._normalize_data_source("hf://org/repo/sub") == {"source": "hf://org/repo/sub"}
-    assert rd._normalize_data_source({"source": "/local/path", "extra": 1}) == {"source": "/local/path", "extra": 1}
-
-    with pytest.raises(ValueError, match="'source' key"):
-        rd._normalize_data_source({"path": "/local/path"})
-
-    with pytest.raises(TypeError, match="str or dict"):
-        rd._normalize_data_source(123)
-
-
 def test_hf_corpus_dataset():
     hf_ds = Dataset.from_list([{"id": "d2", "text": "Doc 2"}, {"id": "d1", "text": "Doc 1"}])
     corpus = rd.HFCorpusDataset(hf_ds, path="hf://org/repo/sub")
@@ -923,8 +912,8 @@ def fake_load_dataset(repo_id, config=None, split=None, **kw):
     assert corpus_info.get_document_by_id("p1")["text"] == "Positive"
 
 
-def test_make_retrieval_dataset_data_sources_hf(tmp_path, monkeypatch):
-    """End-to-end: make_retrieval_dataset with data_sources pointing to an HF URI."""
+def test_make_retrieval_dataset_hf_uri(tmp_path, monkeypatch):
+    """End-to-end: make_retrieval_dataset with an hf:// URI in data_dir_list."""
     meta_path = tmp_path / "dataset_metadata.json"
     meta_path.write_text(
         json.dumps({"corpus_id": "e2e_corpus", "class": "TextQADataset", "ids_only": False})
@@ -957,7 +946,7 @@ def fake_load_dataset(repo_id, config=None, split=None, **kw):
     monkeypatch.setattr(rd, "load_dataset", fake_load_dataset)
 
     ds = rd.make_retrieval_dataset(
-        data_sources=["hf://org/repo/SubA"],
+        data_dir_list=["hf://org/repo/SubA"],
         data_type="train",
         train_n_passages=3,
     )
@@ -987,6 +976,18 @@ def test_transform_func_empty_neg_doc_with_negatives_requested():
     assert out["doc_text"][0] == ["pos"]
 
 
+def test_load_hf_subset_rejects_ids_only(tmp_path, monkeypatch):
+    """ids_only subsets should fail fast with a clear message."""
+    meta_path = tmp_path / "dataset_metadata.json"
+    meta_path.write_text(
+        json.dumps({"corpus_id": "c", "class": "TextQADataset", "ids_only": True})
+    )
+    monkeypatch.setattr(rd, "hf_hub_download", lambda **kw: str(meta_path))
+
+    with pytest.raises(ValueError, match="ids_only=true.*not supported for direct HF loading"):
+        rd._load_hf_subset("org/repo", "SciFact")
+
+
 def test_load_hf_subset_synthesizes_question_id(tmp_path, monkeypatch):
     """Records without question_id get deterministic IDs: {subset}:{row_idx}."""
     meta_path = tmp_path / "dataset_metadata.json"
@@ -1056,19 +1057,13 @@ def test_make_retrieval_dataset_backwards_compat(tmp_path, monkeypatch):
     assert ex["question"] == "Q0"
 
 
-def test_make_retrieval_dataset_requires_source():
-    """Neither data_dir_list nor data_sources raises ValueError."""
-    with pytest.raises(ValueError, match="Either data_dir_list or data_sources"):
+def test_make_retrieval_dataset_requires_data_dir_list():
+    """Calling without data_dir_list raises ValueError."""
+    with pytest.raises(ValueError, match="data_dir_list is required"):
         rd.make_retrieval_dataset()
 
 
-def test_make_retrieval_dataset_rejects_both_inputs():
-    """Providing both data_dir_list and data_sources raises ValueError."""
-    with pytest.raises(ValueError, match="mutually exclusive"):
-        rd.make_retrieval_dataset(data_dir_list=["a.json"], data_sources=["hf://org/repo/Sub"])
-
-
-def test_load_datasets_from_sources_corpus_id_collision_hf_local(tmp_path, monkeypatch):
+def test_make_retrieval_dataset_corpus_id_collision_hf_local(tmp_path, monkeypatch):
     """HF and local sources with same corpus_id but different paths must raise."""
     # Setup HF side
     meta_path = tmp_path / "dataset_metadata.json"
@@ -1112,8 +1107,10 @@ def fake_load_dataset(repo_id_or_path, config=None, split=None, **kw):
     monkeypatch.setattr(rd, "load_dataset", fake_load_dataset)
 
     with pytest.raises(ValueError, match="Duplicate corpus_id.*shared_id.*different paths"):
-        rd.load_datasets_from_sources(
-            ["hf://org/repo/Sub", str(local_file)]
+        rd.make_retrieval_dataset(
+            data_dir_list=["hf://org/repo/Sub", str(local_file)],
+            data_type="train",
+            train_n_passages=2,
         )
 
 