Skip to content

Commit c9cb743

Browse files
committed
Refactor WebDataset reader and reduce code density
1 parent 39617df commit c9cb743

File tree

2 files changed

+110
-75
lines changed

2 files changed

+110
-75
lines changed

WEBDATASET_REDUCTION_PLAN.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# WebDataset Reader Reduction Plan
2+
3+
## Scope
4+
Target file: `nemo_curator/stages/multimodal/io/readers/webdataset.py`
5+
6+
Goal: reduce density and repeated logic while preserving behavior and test outcomes.
7+
8+
## Principles
9+
1. No behavior changes unless explicitly called out and reviewed.
10+
2. Keep exception semantics stable (`raise`/`skip`/`log`).
11+
3. Run tests after each phase, not only at the end.
12+
4. Prefer removing duplicated branches over adding many small helpers.
13+
14+
## Current Pain Points
15+
1. Interleaved row creation has repeated text/image branch logic.
16+
2. Text-member parsing combines multiple concerns in one path.
17+
3. Fallback/error handling around interleaved parsing is harder to scan than needed.
18+
4. Member-type dispatch has repeated suffix handling patterns.
19+
20+
## Reduction Phases
21+
22+
### Phase 1: Control-flow cleanup (low risk)
23+
1. Keep current behavior, but collapse redundant exception branches where invariant checks already exist.
24+
2. Keep one clear fallback path for `sample_format != "interleaved"`.
25+
3. Ensure all error messages that tests rely on remain unchanged.
26+
27+
Acceptance:
28+
1. Existing multimodal tests pass unchanged.
29+
2. No schema/output changes.
30+
31+
### Phase 2: Duplicate row-construction removal (low-medium risk)
32+
1. Centralize interleaved segment row construction into one reusable function.
33+
2. Reuse that function from all interleaved segment paths.
34+
3. Keep modality checks and metadata JSON population exactly as today.
35+
36+
Acceptance:
37+
1. Interleaved reader/writer roundtrip tests pass.
38+
2. Element metadata JSON behavior remains identical.
39+
40+
### Phase 3: Text-member path decomposition (medium risk)
41+
1. Split JSON text-member path from plain text-member path for readability.
42+
2. Keep shared row finalization in one place to avoid re-duplicating code.
43+
3. Retain existing sample-id/position assignment rules.
44+
45+
Acceptance:
46+
1. Non-interleaved JSON fallback behavior remains unchanged.
47+
2. Metadata sidecar population stays first-wins and identical to current tests.
48+
49+
### Phase 4: Optional line-count pass (optional)
50+
1. Re-evaluate helper count vs readability.
51+
2. Inline helpers that only wrap one call and do not improve clarity.
52+
3. Keep final structure flat and easy to follow.
53+
54+
Acceptance:
55+
1. Net reduction in repeated branches and lines where practical.
56+
2. No decrease in maintainability/readability.
57+
58+
## Test Gate Per Phase
59+
Run:
60+
1. `pytest -q tests/stages/multimodal/test_writer_output_formats.py`
61+
2. `pytest -q tests/stages/multimodal/test_parquet_reader.py`
62+
63+
## Deliverables
64+
1. One commit per phase (or grouped Phase 1+2 if very small).
65+
2. Updated diff summary with:
66+
1. duplicated branches removed
67+
2. net insertions/deletions
68+
3. test results

nemo_curator/stages/multimodal/io/readers/webdataset.py

Lines changed: 42 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -160,15 +160,15 @@ def default_interleaved_field_map() -> dict[str, str]:
160160
def __post_init__(self) -> None:
161161
"""Validate reader configuration."""
162162
super().__post_init__()
163-
if self.sample_format not in _SUPPORTED_SAMPLE_FORMATS:
164-
msg = f"Unsupported sample_format='{self.sample_format}'. Expected one of: auto, simple, interleaved"
165-
raise ValueError(msg)
166-
if self.modalities_to_load not in _SUPPORTED_MODALITIES_TO_LOAD:
167-
msg = f"Unsupported modalities_to_load='{self.modalities_to_load}'. Expected one of: all, image, text"
168-
raise ValueError(msg)
169-
if self.error_handling not in _SUPPORTED_ERROR_HANDLING:
170-
msg = f"Unsupported error_handling='{self.error_handling}'. Expected one of: raise, skip, log"
171-
raise ValueError(msg)
163+
for field_name, value, supported in (
164+
("sample_format", self.sample_format, _SUPPORTED_SAMPLE_FORMATS),
165+
("modalities_to_load", self.modalities_to_load, _SUPPORTED_MODALITIES_TO_LOAD),
166+
("error_handling", self.error_handling, _SUPPORTED_ERROR_HANDLING),
167+
):
168+
if value not in supported:
169+
options = ", ".join(sorted(supported))
170+
msg = f"Unsupported {field_name}='{value}'. Expected one of: {options}"
171+
raise ValueError(msg)
172172
default_map = self.default_interleaved_field_map()
173173
unknown = sorted(set(self.interleaved_field_map or {}) - set(default_map))
174174
if unknown:
@@ -244,25 +244,22 @@ def _rows_from_text_member(
244244
source: RowSource,
245245
) -> list[dict[str, object]]:
246246
if suffix == ".json":
247-
parsed = self._maybe_rows_from_interleaved_json_member(payload, source, state, member_name)
248-
if parsed is not None:
249-
return parsed
247+
if payload is None:
248+
msg = f"JSON member '{member_name}' missing payload bytes"
249+
raise WebDatasetMemberParseError(msg)
250+
try:
251+
return self._rows_from_interleaved_json(payload, source, state)
252+
except WebDatasetMemberParseError:
253+
if self.sample_format == "interleaved":
254+
raise
250255
if not self._loads_modality("text"):
251256
return []
252257
sid, position = self._next_sample_and_position(state.sample_counters, member_name, "text")
253258
text_content = self._decode_text_payload(payload, member_name)
254259
content_type = "application/json" if suffix == ".json" else "text/plain"
255260
if suffix == ".json":
256261
_record_metadata_row(state, sid, text_content or "{}")
257-
return [
258-
self._text_row(
259-
sid=sid,
260-
position=position,
261-
source_shard=source.source_shard,
262-
content_type=content_type,
263-
text_content=text_content,
264-
)
265-
]
262+
return [self._text_row(sid=sid, position=position, source_shard=source.source_shard, content_type=content_type, text_content=text_content)]
266263

267264
def _rows_from_interleaved_json(
268265
self,
@@ -275,11 +272,11 @@ def _rows_from_interleaved_json(
275272
sample_payload = dict(decoded)
276273
sample_payload.pop(self.interleaved_field_map["segments"], None)
277274
_record_metadata_row(state, sample_id, self._json_or_none(sample_payload) or "{}")
278-
rows: list[dict[str, object]] = []
279275
field_map = self.interleaved_field_map
280276
modality_field = field_map["modality"]
281277
text_field = field_map["text"]
282278
content_key_field = field_map["content_key"]
279+
rows: list[dict[str, object]] = []
283280
for idx, segment in enumerate(segments):
284281
modality = _required_segment_str(segment, modality_field)
285282
if modality not in _SUPPORTED_INTERLEAVED_MODALITIES:
@@ -288,53 +285,31 @@ def _rows_from_interleaved_json(
288285
"in WebDatasetReaderStage (supported: text, image)"
289286
)
290287
raise WebDatasetMemberParseError(msg)
291-
if self._loads_modality(modality):
292-
if modality == "text":
293-
rows.append(
294-
self._text_row(
295-
sid=sample_id,
296-
position=idx,
297-
source_shard=source.source_shard,
298-
content_type="text/plain",
299-
text_content=_required_segment_str(segment, text_field),
300-
element_metadata_json=self._json_or_none(segment),
301-
)
302-
)
303-
else:
304-
rows.append(
305-
self._image_row(
306-
sid=sample_id,
307-
position=idx,
308-
source=source,
309-
content_key=_required_segment_str(segment, content_key_field),
310-
element_metadata_json=self._json_or_none(segment),
311-
)
288+
if not self._loads_modality(modality):
289+
continue
290+
if modality == "text":
291+
rows.append(
292+
self._text_row(
293+
sid=sample_id,
294+
position=idx,
295+
source_shard=source.source_shard,
296+
content_type="text/plain",
297+
text_content=_required_segment_str(segment, text_field),
298+
element_metadata_json=self._json_or_none(segment),
312299
)
300+
)
301+
continue
302+
rows.append(
303+
self._image_row(
304+
sid=sample_id,
305+
position=idx,
306+
source=source,
307+
content_key=_required_segment_str(segment, content_key_field),
308+
element_metadata_json=self._json_or_none(segment),
309+
)
310+
)
313311
return rows
314312

315-
def _maybe_rows_from_interleaved_json_member(
316-
self,
317-
payload: bytes | None,
318-
source: RowSource,
319-
state: RowBuildState,
320-
member_name: str,
321-
) -> list[dict[str, object]] | None:
322-
if payload is None:
323-
msg = f"JSON member '{member_name}' missing payload bytes"
324-
raise WebDatasetMemberParseError(msg)
325-
try:
326-
parsed = self._rows_from_interleaved_json(payload, source, state)
327-
except WebDatasetMemberParseError:
328-
if self.sample_format == "interleaved":
329-
raise
330-
return None
331-
except KeyError as err:
332-
if self.sample_format == "interleaved":
333-
msg = f"Interleaved JSON missing required field: {err}"
334-
raise WebDatasetMemberParseError(msg) from err
335-
return None
336-
return parsed
337-
338313
@staticmethod
339314
def _decode_text_payload(payload: bytes | None, member_name: str) -> str:
340315
if payload is None:
@@ -368,15 +343,7 @@ def _rows_from_binary_member(
368343
if not self._loads_modality(modality):
369344
return []
370345
sid, position = self._next_sample_and_position(state.sample_counters, member_name, modality)
371-
return [
372-
self._image_row(
373-
sid=sid,
374-
position=position,
375-
source=source,
376-
content_key=member_name,
377-
binary_content=payload if self.load_binary else None,
378-
)
379-
]
346+
return [self._image_row(sid=sid, position=position, source=source, content_key=member_name, binary_content=payload if self.load_binary else None)]
380347

381348
def _next_sample_and_position(
382349
self,

0 commit comments

Comments
 (0)