Skip to content

Commit 96ec3da

Browse files
committed
Refine WebDataset validation helpers and simplify init checks
1 parent c9cb743 commit 96ec3da

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

nemo_curator/stages/multimodal/io/readers/webdataset.py

Lines changed: 27 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -119,6 +119,26 @@ def _validate_interleaved_payload(
119119
return sample_id, typed_segments
120120

121121

122+
def _validate_option(value: str, *, field_name: str, supported: set[str], options_label: str) -> None:
123+
if value not in supported:
124+
msg = f"Unsupported {field_name}='{value}'. Expected one of: {options_label}"
125+
raise ValueError(msg)
126+
127+
128+
def _resolve_interleaved_field_map(overrides: dict[str, str] | None, default_map: dict[str, str]) -> dict[str, str]:
129+
unknown = sorted(set(overrides or {}) - set(default_map))
130+
if unknown:
131+
msg = f"interleaved_field_map has unknown keys: {unknown}"
132+
raise ValueError(msg)
133+
resolved = default_map
134+
resolved.update(overrides or {})
135+
for semantic, actual in resolved.items():
136+
if not isinstance(actual, str) or not actual:
137+
msg = f"interleaved_field_map['{semantic}'] must be a non-empty string"
138+
raise ValueError(msg)
139+
return resolved
140+
141+
122142
@dataclass
123143
class WebDatasetReaderStage(BaseMultimodalReaderStage):
124144
"""Parse WebDataset tar shards into normalized multimodal rows.
@@ -160,27 +180,13 @@ def default_interleaved_field_map() -> dict[str, str]:
160180
def __post_init__(self) -> None:
    """Validate reader configuration.

    Runs the parent ``__post_init__`` first, then checks that
    ``sample_format``, ``modalities_to_load`` and ``error_handling`` each hold
    a supported value, and finally replaces ``interleaved_field_map`` with the
    defaults merged with any user overrides.

    Raises:
        ValueError: If an option is unsupported, or if
            ``interleaved_field_map`` has unknown keys or a value that is not
            a non-empty string.
    """
    super().__post_init__()
    for field_name, value, supported in (
        ("sample_format", self.sample_format, _SUPPORTED_SAMPLE_FORMATS),
        ("modalities_to_load", self.modalities_to_load, _SUPPORTED_MODALITIES_TO_LOAD),
        ("error_handling", self.error_handling, _SUPPORTED_ERROR_HANDLING),
    ):
        # Derive the label from the supported set itself so the error message
        # can never drift from the values that are actually accepted.
        _validate_option(
            value,
            field_name=field_name,
            supported=supported,
            options_label=", ".join(sorted(supported)),
        )
    self.interleaved_field_map = _resolve_interleaved_field_map(
        self.interleaved_field_map,
        self.default_interleaved_field_map(),
    )
184190

185191
def read_data(self, data_path: str, metadata_path: str | None) -> tuple[pa.Table, pa.Table]:
186192
"""Read one tar shard into normalized data and metadata tables.

0 commit comments

Comments
 (0)