Skip to content

Commit 793c4bb

Browse files
committed
Preserve per-row text metadata in collapsed WebDataset segments
1 parent 3afbc97 commit 793c4bb

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

nemo_curator/stages/multimodal/io/writers/multimodal.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -200,22 +200,17 @@ def _write_webdataset_to_fileobj(task: MultimodalBatch, fileobj: BinaryIO) -> No
200200

201201
first_text_index: dict[str, int] = {}
202202
text_row_count: dict[str, int] = {}
203-
merged_text_payload: dict[str, bytes] = {}
204203
text_segments: dict[str, list[dict[str, object]]] = {}
205204
has_text_segment_metadata: dict[str, bool] = {}
206205
for idx, (sample_id, modality) in enumerate(zip(sample_ids, modalities, strict=True)):
207206
sid = str(sample_id)
208207
if str(modality) == "text":
209208
if sid not in first_text_index:
210209
first_text_index[sid] = idx
211-
merged_text_payload[sid] = b""
212210
text_row_count[sid] = 0
213211
text_segments[sid] = []
214212
has_text_segment_metadata[sid] = False
215213
text_row_count[sid] += 1
216-
current = merged_text_payload[sid]
217-
text_bytes = str(text_contents[idx] or "").encode("utf-8")
218-
merged_text_payload[sid] = text_bytes if current == b"" else current + b"\n" + text_bytes
219214
segment: dict[str, object] = {"modality": "text", "text": str(text_contents[idx] or "")}
220215
text_row_metadata = MultimodalWriterStage._parse_json_or_raw(element_metadata_jsons[idx])
221216
if text_row_metadata is not None:
@@ -233,7 +228,6 @@ def _write_webdataset_to_fileobj(task: MultimodalBatch, fileobj: BinaryIO) -> No
233228
suffix, payload = MultimodalWriterStage._text_suffix_and_payload(
234229
sample_id=sid,
235230
content_type=content_types[idx],
236-
merged_text_payload=merged_text_payload[sid],
237231
text_segments=text_segments[sid],
238232
include_segment_metadata=has_text_segment_metadata[sid],
239233
)
@@ -257,7 +251,6 @@ def _text_suffix_and_payload(
257251
*,
258252
sample_id: str,
259253
content_type: object | None,
260-
merged_text_payload: bytes,
261254
text_segments: list[dict[str, object]],
262255
include_segment_metadata: bool,
263256
) -> tuple[str, bytes]:
@@ -273,7 +266,8 @@ def _text_suffix_and_payload(
273266
return "json", json.dumps(payload, ensure_ascii=True).encode("utf-8")
274267
ctype = str(content_type) if content_type is not None else "text/plain"
275268
suffix = "json" if ctype == "application/json" else "txt"
276-
return suffix, merged_text_payload
269+
text_payload = "\n".join(str(segment.get("text", "")) for segment in text_segments)
270+
return suffix, text_payload.encode("utf-8")
277271

278272
@staticmethod
279273
def _parse_json_or_raw(value: object | None) -> object | None:

tests/stages/multimodal/test_writer_output_formats.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,62 @@ def test_webdataset_writer_collapsed_text_preserves_element_metadata_json(tmp_pa
614614
assert json.loads(str(rows[1]["element_metadata_json"]))["element_metadata_json"]["lang"] == "en"
615615

616616

617+
def test_webdataset_writer_collapsed_text_writes_full_segment_metadata_payload(tmp_path: Path) -> None:
618+
out = tmp_path / "collapsed_text_full_metadata.tar"
619+
table = pa.table(
620+
{
621+
"sample_id": ["doc", "doc", "doc", "doc"],
622+
"position": [0, 1, 2, 3],
623+
"modality": ["text", "text", "text", "image"],
624+
"content_type": ["text/plain", "text/plain", "text/plain", "image/jpeg"],
625+
"text_content": ["alpha", "beta", "gamma", None],
626+
"binary_content": [None, None, None, b"img"],
627+
"element_metadata_json": [
628+
'{"quality": 0.91, "token_count": 1}',
629+
'{"quality": 0.77, "lang": "en", "attrs": {"source": "ocr"}}',
630+
'{"quality": 0.55, "tags": ["x", "y"]}',
631+
None,
632+
],
633+
"source_id": ["src", "src", "src", "src"],
634+
"source_shard": ["shard", "shard", "shard", "shard"],
635+
"content_path": [None, None, None, None],
636+
"content_key": [None, None, None, "doc.jpg"],
637+
},
638+
schema=MULTIMODAL_SCHEMA,
639+
)
640+
task = MultimodalBatch(task_id="t-full-meta", dataset_name="ds", data=table)
641+
result = MultimodalWriterStage(output_path=str(out), output_format="webdataset").process(task)
642+
names, members = _read_tar_members(Path(result.data[0]))
643+
644+
assert names == ["doc.000000.json", "doc.000003.jpg"]
645+
payload = json.loads(members["doc.000000.json"].decode("utf-8"))
646+
assert payload["sample_id"] == "doc"
647+
assert [segment["text"] for segment in payload["segments"]] == ["alpha", "beta", "gamma"]
648+
assert payload["segments"][0]["element_metadata_json"] == {"quality": 0.91, "token_count": 1}
649+
assert payload["segments"][1]["element_metadata_json"] == {
650+
"quality": 0.77,
651+
"lang": "en",
652+
"attrs": {"source": "ocr"},
653+
}
654+
assert payload["segments"][2]["element_metadata_json"] == {"quality": 0.55, "tags": ["x", "y"]}
655+
656+
roundtrip = WebDatasetReaderStage(load_binary=False, sample_format="auto").process(
657+
FileGroupTask(task_id="rt-full-meta", dataset_name="ds", data=[result.data[0]])
658+
)
659+
rows = sorted(
660+
[row for row in roundtrip.data.to_pylist() if row["modality"] == "text"],
661+
key=lambda row: int(row["position"]),
662+
)
663+
assert [row["text_content"] for row in rows] == ["alpha", "beta", "gamma"]
664+
assert json.loads(str(rows[0]["element_metadata_json"]))["element_metadata_json"] == {"quality": 0.91, "token_count": 1}
665+
assert json.loads(str(rows[1]["element_metadata_json"]))["element_metadata_json"] == {
666+
"quality": 0.77,
667+
"lang": "en",
668+
"attrs": {"source": "ocr"},
669+
}
670+
assert json.loads(str(rows[2]["element_metadata_json"]))["element_metadata_json"] == {"quality": 0.55, "tags": ["x", "y"]}
671+
672+
617673
def test_webdataset_writer_allows_text_only_batch(tmp_path: Path) -> None:
618674
out = tmp_path / "text-only.tar"
619675
table = pa.table(

0 commit comments

Comments
 (0)