Skip to content

Commit 53ed62a

Browse files
committed
Clean up WebDataset parsing flow and make sample format handling explicit
1 parent 96ec3da commit 53ed62a

File tree

2 files changed

+39
-23
lines changed

2 files changed

+39
-23
lines changed

nemo_curator/stages/multimodal/io/readers/webdataset.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -200,13 +200,14 @@ def read_data(self, data_path: str, metadata_path: str | None) -> tuple[pa.Table
200200

201201
with open_tar_path(data_path, self.storage_options) as tf:
202202
for member in tf:
203-
if member.isfile():
204-
member_name = member.name
205-
try:
206-
payload = self._member_payload(tf, member_name, member)
207-
rows.extend(self._rows_from_member(state, member_name, payload, source))
208-
except (OSError, UnicodeDecodeError, json.JSONDecodeError, WebDatasetMemberParseError) as err:
209-
self._handle_member_error(member_name, err)
203+
if not member.isfile():
204+
continue
205+
member_name = member.name
206+
try:
207+
payload = self._member_payload(tf, member_name, member)
208+
rows.extend(self._rows_from_member(state, member_name, payload, source))
209+
except (OSError, WebDatasetMemberParseError) as err:
210+
self._handle_member_error(member_name, err)
210211
return self._rows_to_table(rows), pa.Table.from_pylist(state.metadata_rows, schema=METADATA_SCHEMA)
211212

212213
def _member_payload(self, tf: tarfile.TarFile, member_name: str, member: tarfile.TarInfo) -> bytes | None:
@@ -273,7 +274,7 @@ def _rows_from_interleaved_json(
273274
source: RowSource,
274275
state: RowBuildState,
275276
) -> list[dict[str, object]]:
276-
decoded = json.loads(payload.decode("utf-8"))
277+
decoded = self._decode_json_payload(payload, context="interleaved JSON payload")
277278
sample_id, segments = _validate_interleaved_payload(decoded, self.interleaved_field_map)
278279
sample_payload = dict(decoded)
279280
sample_payload.pop(self.interleaved_field_map["segments"], None)
@@ -294,34 +295,46 @@ def _rows_from_interleaved_json(
294295
if not self._loads_modality(modality):
295296
continue
296297
if modality == "text":
297-
rows.append(
298-
self._text_row(
299-
sid=sample_id,
300-
position=idx,
301-
source_shard=source.source_shard,
302-
content_type="text/plain",
303-
text_content=_required_segment_str(segment, text_field),
304-
element_metadata_json=self._json_or_none(segment),
305-
)
298+
row = self._text_row(
299+
sid=sample_id,
300+
position=idx,
301+
source_shard=source.source_shard,
302+
content_type="text/plain",
303+
text_content=_required_segment_str(segment, text_field),
304+
element_metadata_json=self._json_or_none(segment),
306305
)
307-
continue
308-
rows.append(
309-
self._image_row(
306+
else:
307+
row = self._image_row(
310308
sid=sample_id,
311309
position=idx,
312310
source=source,
313311
content_key=_required_segment_str(segment, content_key_field),
314312
element_metadata_json=self._json_or_none(segment),
315313
)
316-
)
314+
rows.append(row)
317315
return rows
318316

319317
@staticmethod
320318
def _decode_text_payload(payload: bytes | None, member_name: str) -> str:
321319
if payload is None:
322320
msg = f"Text member '{member_name}' missing payload bytes"
323321
raise WebDatasetMemberParseError(msg)
324-
return payload.decode("utf-8") if payload else ""
322+
try:
323+
return payload.decode("utf-8") if payload else ""
324+
except UnicodeDecodeError as err:
325+
msg = f"Text member '{member_name}' must be valid UTF-8"
326+
raise WebDatasetMemberParseError(msg) from err
327+
328+
@staticmethod
329+
def _decode_json_payload(payload: bytes, *, context: str) -> object:
330+
try:
331+
return json.loads(payload.decode("utf-8"))
332+
except UnicodeDecodeError as err:
333+
msg = f"{context} must be valid UTF-8 JSON"
334+
raise WebDatasetMemberParseError(msg) from err
335+
except json.JSONDecodeError as err:
336+
msg = f"{context} must be valid JSON"
337+
raise WebDatasetMemberParseError(msg) from err
325338

326339
@staticmethod
327340
def _binary_modality_for_member(member_name: str) -> str:

nemo_curator/utils/webdataset_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,10 @@ def parse_sample_and_position(
5050
return stem, 0 if modality == "image" else 1
5151
if sample_format == "interleaved":
5252
return stem, fallback_position
53-
return stem, 0 if modality == "text" else 1
53+
if sample_format == "auto":
54+
return stem, 0 if modality == "text" else 1
55+
msg = f"Unsupported sample_format='{sample_format}'. Expected one of: auto, simple, interleaved"
56+
raise ValueError(msg)
5457

5558

5659
def content_type_from_name(name: str, default: str = DEFAULT_BINARY_CONTENT_TYPE) -> str:

0 commit comments

Comments
 (0)