Commit 3afbc97

Tighten multimodal reader/writer contracts and reduce reader complexity

1 parent 3dbb09f commit 3afbc97
11 files changed: +537 −251 lines changed

nemo_curator/stages/multimodal/io/readers/base.py

Lines changed: 10 additions & 67 deletions
@@ -10,7 +10,6 @@
 
 import json
 from abc import ABC, abstractmethod
-from collections import OrderedDict
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
 
@@ -23,14 +22,17 @@
 from nemo_curator.tasks.multimodal import METADATA_SCHEMA, MULTIMODAL_SCHEMA
 from nemo_curator.utils.file_utils import resolve_fs_and_path
 from nemo_curator.utils.grouping import split_by_chunk_size
-from nemo_curator.utils.multimodal_utils import sort_multimodal_table
+from nemo_curator.utils.multimodal_utils import (
+    metadata_map_from_tables,
+    metadata_rows_for_table,
+    sort_multimodal_table,
+)
 from nemo_curator.utils.webdataset_utils import content_type_from_name
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
 
 ReaderTask = FileGroupTask | tuple[FileGroupTask, FileGroupTask | None]
-_PAIR_ELEMENT_COUNT = 2
 
 
 @dataclass
@@ -128,7 +130,7 @@ def _build_batches_from_tables(
     ) -> MultimodalBatch | list[MultimodalBatch]:
         table = self._concat_data_tables_or_empty(data_tables)
         table = sort_multimodal_table(table)
-        metadata_by_sample = self._metadata_map_from_tables(metadata_tables)
+        metadata_by_sample = metadata_map_from_tables(metadata_tables)
         table_splits = self.split_table(table)
         batches = [
             self._build_batch(
@@ -181,7 +183,7 @@ def split_table_by_sample_max_bytes(self, table: pa.Table, max_batch_bytes: int)
         """Split table by sample groups while preserving sample row locality."""
         if table.num_rows == 0:
             return [table]
-        row_indices_by_sample: OrderedDict[str, list[int]] = OrderedDict()
+        row_indices_by_sample: dict[str, list[int]] = {}
         for idx, sample_id in enumerate(table["sample_id"].to_pylist()):
             sid = str(sample_id)
             row_indices_by_sample.setdefault(sid, [])
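
Aside: the OrderedDict-to-dict swaps in this commit rely on plain dicts preserving insertion order, which is guaranteed since Python 3.7, so first-seen sample ordering is unchanged. A minimal illustration (sample values made up):

# dict keeps keys in first-insertion order, so grouping row indices
# by sample_id preserves the table's sample order without OrderedDict.
row_indices_by_sample: dict[str, list[int]] = {}
for idx, sample_id in enumerate(["s2", "s1", "s2", "s3"]):
    row_indices_by_sample.setdefault(str(sample_id), []).append(idx)

assert list(row_indices_by_sample) == ["s2", "s1", "s3"]
assert row_indices_by_sample["s2"] == [0, 2]
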
@@ -205,6 +207,7 @@ def _text_row(  # noqa: PLR0913
         content_type: str,
         text_content: str,
         element_metadata_json: str | None = None,
+        source_id: str | None = None,
     ) -> dict[str, object]:
         """Build one normalized text row payload."""
         return {
@@ -215,7 +218,7 @@ def _text_row(  # noqa: PLR0913
             "text_content": text_content,
             "binary_content": None,
             "element_metadata_json": element_metadata_json,
-            "source_id": sid,
+            "source_id": source_id or sid,
             "source_shard": source_shard,
             "content_path": None,
             "content_key": None,
@@ -284,24 +287,6 @@ def _task_metadata(self, task: FileGroupTask) -> dict[str, Any]:
         """Propagate task metadata and attach storage options used for reads."""
         return {**task._metadata, "storage_options": dict(self.storage_options)}
 
-    @staticmethod
-    def _metadata_map_from_tables(metadata_tables: list[pa.Table]) -> dict[str, str]:
-        """Build first-wins sample->metadata_json map from metadata tables."""
-        metadata_by_sample: dict[str, str] = {}
-        for metadata_table in metadata_tables:
-            has_rows = metadata_table.num_rows > 0
-            has_sample_id = "sample_id" in metadata_table.column_names
-            if has_rows and has_sample_id:
-                sample_ids = metadata_table["sample_id"].to_pylist()
-                if "metadata_json" in metadata_table.column_names:
-                    metadata_json_values = metadata_table["metadata_json"].to_pylist()
-                else:
-                    metadata_json_values = [None] * len(sample_ids)
-                for sample_id, metadata_json in zip(sample_ids, metadata_json_values, strict=True):
-                    if isinstance(metadata_json, str):
-                        metadata_by_sample.setdefault(str(sample_id), metadata_json)
-        return metadata_by_sample
-
     def _build_batch(
         self,
         task: FileGroupTask,
@@ -311,7 +296,7 @@
         split_output: bool,
     ) -> MultimodalBatch:
         """Assemble one ``MultimodalBatch`` from normalized data and metadata."""
-        metadata_rows = self._metadata_rows_for_table(table, metadata_by_sample)
+        metadata_rows = metadata_rows_for_table(table, metadata_by_sample)
         metadata_table = (
             pa.Table.from_pylist(metadata_rows, schema=METADATA_SCHEMA)
             if metadata_rows
@@ -326,45 +311,3 @@
             _metadata=self._task_metadata(task),
             _stage_perf=task._stage_perf,
         )
-
-    @staticmethod
-    def _metadata_rows_for_table(
-        table: pa.Table,
-        metadata_by_sample: dict[str, str],
-    ) -> list[dict[str, object]]:
-        """Build metadata rows with single-pass sample type inference."""
-        if table.num_rows == 0:
-            return []
-
-        sample_stats: OrderedDict[str, tuple[int, bool, bool]] = OrderedDict()
-        for sample_id, modality in zip(table["sample_id"].to_pylist(), table["modality"].to_pylist(), strict=True):
-            sid = str(sample_id)
-            modality_name = str(modality)
-            count, has_image, has_text = sample_stats.get(sid, (0, False, False))
-            sample_stats[sid] = (
-                count + 1,
-                has_image or modality_name == "image",
-                has_text or modality_name == "text",
-            )
-
-        return [
-            {
-                "sample_id": sid,
-                "sample_type": BaseMultimodalReaderStage._sample_type_from_summary(
-                    num_rows=num_rows,
-                    has_image=has_image,
-                    has_text=has_text,
-                ),
-                "metadata_json": metadata_by_sample.get(sid),
-            }
-            for sid, (num_rows, has_image, has_text) in sample_stats.items()
-        ]
-
-    @staticmethod
-    def _sample_type_from_summary(num_rows: int, has_image: bool, has_text: bool) -> str:
-        """Infer sample type from in-sample modality ordering."""
-        if num_rows == 1:
-            return "single"
-        if num_rows == _PAIR_ELEMENT_COUNT and has_image and has_text:
-            return "pair"
-        return "interleaved"
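
The removed helpers survive as the free functions imported above from nemo_curator.utils.multimodal_utils, with the sample-type inference presumably folded in behind metadata_rows_for_table. For reference, a standalone sketch of metadata_map_from_tables reconstructed from the removed method body (the relocated code may differ in detail):

import pyarrow as pa


def metadata_map_from_tables(metadata_tables: list[pa.Table]) -> dict[str, str]:
    """Build a first-wins sample_id -> metadata_json map from metadata tables."""
    metadata_by_sample: dict[str, str] = {}
    for metadata_table in metadata_tables:
        # Skip empty tables and tables without a sample_id column.
        if metadata_table.num_rows == 0 or "sample_id" not in metadata_table.column_names:
            continue
        sample_ids = metadata_table["sample_id"].to_pylist()
        if "metadata_json" in metadata_table.column_names:
            metadata_json_values = metadata_table["metadata_json"].to_pylist()
        else:
            metadata_json_values = [None] * len(sample_ids)
        for sample_id, metadata_json in zip(sample_ids, metadata_json_values, strict=True):
            # setdefault keeps the first string value seen for each sample.
            if isinstance(metadata_json, str):
                metadata_by_sample.setdefault(str(sample_id), metadata_json)
    return metadata_by_sample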

nemo_curator/stages/multimodal/io/readers/parquet.py

Lines changed: 17 additions & 14 deletions
@@ -57,26 +57,35 @@ def __post_init__(self) -> None:
     @staticmethod
     def _validate_columns(columns: list[str] | None) -> list[str] | None:
         """Validate optional data column selection."""
-        return ParquetMultimodalReaderStage._validate_column_selection(columns, option_name="columns")
+        if columns is None:
+            return None
+        if len(columns) == 0:
+            msg = "columns must be a non-empty list when provided"
+            raise ValueError(msg)
+        seen: set[str] = set()
+        normalized: list[str] = []
+        for column in columns:
+            if not isinstance(column, str) or not column:
+                msg = "columns entries must be non-empty strings"
+                raise ValueError(msg)
+            if column not in seen:
+                seen.add(column)
+                normalized.append(column)
+        return normalized
 
     @staticmethod
     def _validate_metadata_columns(columns: list[str] | None) -> list[str] | None:
         """Validate optional metadata sidecar column selection."""
-        return ParquetMultimodalReaderStage._validate_column_selection(columns, option_name="metadata_columns")
-
-    @staticmethod
-    def _validate_column_selection(columns: list[str] | None, option_name: str) -> list[str] | None:
-        """Validate and de-duplicate a selected column list."""
         if columns is None:
             return None
         if len(columns) == 0:
-            msg = f"{option_name} must be a non-empty list when provided"
+            msg = "metadata_columns must be a non-empty list when provided"
             raise ValueError(msg)
         seen: set[str] = set()
         normalized: list[str] = []
         for column in columns:
             if not isinstance(column, str) or not column:
-                msg = f"{option_name} entries must be non-empty strings"
+                msg = "metadata_columns entries must be non-empty strings"
                 raise ValueError(msg)
             if column not in seen:
                 seen.add(column)
@@ -165,12 +174,6 @@ class ParquetMultimodalReader(CompositeStage[_EmptyTask, MultimodalBatch]):
 
     def __post_init__(self) -> None:
         super().__init__()
-        if isinstance(self.file_paths, str) and not self.file_paths.endswith(".parquet"):
-            msg = (
-                "When file_paths is a string, it must point to a .parquet file. "
-                "Use an explicit list of parquet file paths when reading multiple files."
-            )
-            raise ValueError(msg)
 
     def decompose(self) -> list[ProcessingStage]:
         return [
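
With _validate_column_selection inlined, each validator now owns its error messages, and the string-path ".parquet" suffix check is dropped from __post_init__. The validation contract itself is unchanged: None passes through, an empty list or non-string entry raises ValueError, and duplicates are dropped first-wins while preserving order. A standalone sketch of that contract (the free function is illustrative; the real logic lives on ParquetMultimodalReaderStage):

def validate_columns(columns: list[str] | None) -> list[str] | None:
    """Mirror of the inlined validator, for illustration only."""
    if columns is None:
        return None  # no selection means "read everything"
    if len(columns) == 0:
        raise ValueError("columns must be a non-empty list when provided")
    seen: set[str] = set()
    normalized: list[str] = []
    for column in columns:
        if not isinstance(column, str) or not column:
            raise ValueError("columns entries must be non-empty strings")
        if column not in seen:  # de-duplicate, keeping first occurrence
            seen.add(column)
            normalized.append(column)
    return normalized


assert validate_columns(None) is None
assert validate_columns(["a", "b", "a"]) == ["a", "b"]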

nemo_curator/stages/multimodal/io/readers/webdataset.py

Lines changed: 21 additions & 12 deletions
@@ -49,6 +49,10 @@
 InterleavedSegment = dict[str, object]
 
 
+class WebDatasetMemberParseError(ValueError):
+    """Expected parse/validation failure for one WebDataset member."""
+
+
 @dataclass
 class RowBuildState:
     """Per-shard mutable parse state.
@@ -82,7 +86,7 @@ def _required_segment_str(segment: InterleavedSegment, field: str) -> str:
     value = segment.get(field)
     if not isinstance(value, str) or not value:
         msg = f"Interleaved segment must include non-empty string '{field}'"
-        raise ValueError(msg)
+        raise WebDatasetMemberParseError(msg)
     return value
 
 
@@ -92,25 +96,25 @@ def _validate_interleaved_payload(
 ) -> tuple[str, list[InterleavedSegment]]:
     if not isinstance(decoded, dict):
         msg = "Interleaved JSON payload must decode to an object"
-        raise TypeError(msg)
+        raise WebDatasetMemberParseError(msg)
 
     sample_id_field = field_map["sample_id"]
     segments_field = field_map["segments"]
     sample_id = decoded.get(sample_id_field)
     if not isinstance(sample_id, str) or not sample_id:
         msg = f"Interleaved JSON payload must include non-empty string '{sample_id_field}'"
-        raise ValueError(msg)
+        raise WebDatasetMemberParseError(msg)
 
     segments = decoded.get(segments_field)
     if not isinstance(segments, list):
         msg = f"Interleaved JSON payload must include list field '{segments_field}'"
-        raise TypeError(msg)
+        raise WebDatasetMemberParseError(msg)
 
     typed_segments: list[InterleavedSegment] = []
     for idx, segment in enumerate(segments):
         if not isinstance(segment, dict):
             msg = f"Interleaved segment at index={idx} for sample_id='{sample_id}' must be an object"
-            raise TypeError(msg)
+            raise WebDatasetMemberParseError(msg)
         typed_segments.append(segment)
     return sample_id, typed_segments
 
@@ -195,7 +199,7 @@ def read_data(self, data_path: str, metadata_path: str | None) -> tuple[pa.Table
             try:
                 payload = self._member_payload(tf, member_name, member)
                 rows.extend(self._rows_from_member(state, member_name, payload, source))
-            except Exception as err:  # noqa: BLE001
+            except (OSError, UnicodeDecodeError, json.JSONDecodeError, WebDatasetMemberParseError) as err:
                 self._handle_member_error(member_name, err)
         return self._rows_to_table(rows), pa.Table.from_pylist(state.metadata_rows, schema=METADATA_SCHEMA)
 
@@ -283,7 +287,7 @@ def _rows_from_interleaved_json(
                     f"Unsupported interleaved modality='{modality}' for sample_id='{sample_id}' "
                     "in WebDatasetReaderStage (supported: text, image)"
                 )
-                raise ValueError(msg)
+                raise WebDatasetMemberParseError(msg)
             if self._loads_modality(modality):
                 if modality == "text":
                     rows.append(
@@ -317,20 +321,25 @@ def _maybe_rows_from_interleaved_json_member(
     ) -> list[dict[str, object]] | None:
         if payload is None:
             msg = f"JSON member '{member_name}' missing payload bytes"
-            raise ValueError(msg)
+            raise WebDatasetMemberParseError(msg)
         try:
             parsed = self._rows_from_interleaved_json(payload, source, state)
-        except (KeyError, TypeError, ValueError):
+        except WebDatasetMemberParseError:
             if self.sample_format == "interleaved":
                 raise
             return None
+        except KeyError as err:
+            if self.sample_format == "interleaved":
+                msg = f"Interleaved JSON missing required field: {err}"
+                raise WebDatasetMemberParseError(msg) from err
+            return None
         return parsed
 
     @staticmethod
     def _decode_text_payload(payload: bytes | None, member_name: str) -> str:
         if payload is None:
             msg = f"Text member '{member_name}' missing payload bytes"
-            raise ValueError(msg)
+            raise WebDatasetMemberParseError(msg)
         return payload.decode("utf-8") if payload else ""
 
     @staticmethod
@@ -339,13 +348,13 @@ def _binary_modality_for_member(member_name: str) -> str:
         modality = modality_from_content_type(content_type)
         if modality == "unknown":
             msg = f"Unsupported content_type='{content_type}' for member '{member_name}' in WebDatasetReaderStage"
-            raise ValueError(msg)
+            raise WebDatasetMemberParseError(msg)
         if modality != "image":
             msg = (
                 f"Unsupported binary modality='{modality}' for member '{member_name}' "
                 "in WebDatasetReaderStage (supported: image)"
             )
-            raise ValueError(msg)
+            raise WebDatasetMemberParseError(msg)
         return modality
 
     def _rows_from_binary_member(
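
Net effect in webdataset.py: every expected per-member failure is raised as WebDatasetMemberParseError, which lets read_data trade the blanket "except Exception" for a closed set of anticipated errors while real bugs still propagate. A minimal sketch of the pattern (tar layout and helper names here are illustrative, not the stage's real API):

import json
import tarfile


class MemberParseError(ValueError):
    """Expected parse/validation failure for one archive member."""


def parse_member(name: str, payload: bytes) -> dict[str, object]:
    # Raise the dedicated error for every *expected* validation failure.
    if not payload:
        raise MemberParseError(f"member '{name}' missing payload bytes")
    decoded = json.loads(payload)  # json.JSONDecodeError propagates as-is
    if not isinstance(decoded, dict):
        raise MemberParseError(f"member '{name}' must decode to an object")
    return decoded


def read_shard(path: str) -> list[dict[str, object]]:
    rows: list[dict[str, object]] = []
    with tarfile.open(path) as tf:
        for member in tf.getmembers():
            fileobj = tf.extractfile(member)
            payload = fileobj.read() if fileobj is not None else b""
            try:
                rows.append(parse_member(member.name, payload))
            # Catch only the anticipated failure modes; anything else
            # (e.g. AttributeError from a bug) surfaces immediately.
            except (OSError, UnicodeDecodeError, json.JSONDecodeError, MemberParseError) as err:
                print(f"skipping member '{member.name}': {err}")
    return rows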
