1010
1111import json
1212from abc import ABC , abstractmethod
13- from collections import OrderedDict
1413from dataclasses import dataclass , field
1514from typing import TYPE_CHECKING , Any
1615
2322from nemo_curator .tasks .multimodal import METADATA_SCHEMA , MULTIMODAL_SCHEMA
2423from nemo_curator .utils .file_utils import resolve_fs_and_path
2524from nemo_curator .utils .grouping import split_by_chunk_size
26- from nemo_curator .utils .multimodal_utils import sort_multimodal_table
25+ from nemo_curator .utils .multimodal_utils import (
26+ metadata_map_from_tables ,
27+ metadata_rows_for_table ,
28+ sort_multimodal_table ,
29+ )
2730from nemo_curator .utils .webdataset_utils import content_type_from_name
2831
2932if TYPE_CHECKING :
3033 from collections .abc import Iterable
3134
3235ReaderTask = FileGroupTask | tuple [FileGroupTask , FileGroupTask | None ]
33- _PAIR_ELEMENT_COUNT = 2
3436
3537
3638@dataclass
@@ -128,7 +130,7 @@ def _build_batches_from_tables(
128130 ) -> MultimodalBatch | list [MultimodalBatch ]:
129131 table = self ._concat_data_tables_or_empty (data_tables )
130132 table = sort_multimodal_table (table )
131- metadata_by_sample = self . _metadata_map_from_tables (metadata_tables )
133+ metadata_by_sample = metadata_map_from_tables (metadata_tables )
132134 table_splits = self .split_table (table )
133135 batches = [
134136 self ._build_batch (
@@ -181,7 +183,7 @@ def split_table_by_sample_max_bytes(self, table: pa.Table, max_batch_bytes: int)
181183 """Split table by sample groups while preserving sample row locality."""
182184 if table .num_rows == 0 :
183185 return [table ]
184- row_indices_by_sample : OrderedDict [str , list [int ]] = OrderedDict ()
186+ row_indices_by_sample : dict [str , list [int ]] = {}
185187 for idx , sample_id in enumerate (table ["sample_id" ].to_pylist ()):
186188 sid = str (sample_id )
187189 row_indices_by_sample .setdefault (sid , [])
@@ -205,6 +207,7 @@ def _text_row( # noqa: PLR0913
205207 content_type : str ,
206208 text_content : str ,
207209 element_metadata_json : str | None = None ,
210+ source_id : str | None = None ,
208211 ) -> dict [str , object ]:
209212 """Build one normalized text row payload."""
210213 return {
@@ -215,7 +218,7 @@ def _text_row( # noqa: PLR0913
215218 "text_content" : text_content ,
216219 "binary_content" : None ,
217220 "element_metadata_json" : element_metadata_json ,
218- "source_id" : sid ,
221+ "source_id" : source_id or sid ,
219222 "source_shard" : source_shard ,
220223 "content_path" : None ,
221224 "content_key" : None ,
@@ -284,24 +287,6 @@ def _task_metadata(self, task: FileGroupTask) -> dict[str, Any]:
284287 """Propagate task metadata and attach storage options used for reads."""
285288 return {** task ._metadata , "storage_options" : dict (self .storage_options )}
286289
287- @staticmethod
288- def _metadata_map_from_tables (metadata_tables : list [pa .Table ]) -> dict [str , str ]:
289- """Build first-wins sample->metadata_json map from metadata tables."""
290- metadata_by_sample : dict [str , str ] = {}
291- for metadata_table in metadata_tables :
292- has_rows = metadata_table .num_rows > 0
293- has_sample_id = "sample_id" in metadata_table .column_names
294- if has_rows and has_sample_id :
295- sample_ids = metadata_table ["sample_id" ].to_pylist ()
296- if "metadata_json" in metadata_table .column_names :
297- metadata_json_values = metadata_table ["metadata_json" ].to_pylist ()
298- else :
299- metadata_json_values = [None ] * len (sample_ids )
300- for sample_id , metadata_json in zip (sample_ids , metadata_json_values , strict = True ):
301- if isinstance (metadata_json , str ):
302- metadata_by_sample .setdefault (str (sample_id ), metadata_json )
303- return metadata_by_sample
304-
305290 def _build_batch (
306291 self ,
307292 task : FileGroupTask ,
@@ -311,7 +296,7 @@ def _build_batch(
311296 split_output : bool ,
312297 ) -> MultimodalBatch :
313298 """Assemble one ``MultimodalBatch`` from normalized data and metadata."""
314- metadata_rows = self . _metadata_rows_for_table (table , metadata_by_sample )
299+ metadata_rows = metadata_rows_for_table (table , metadata_by_sample )
315300 metadata_table = (
316301 pa .Table .from_pylist (metadata_rows , schema = METADATA_SCHEMA )
317302 if metadata_rows
@@ -326,45 +311,3 @@ def _build_batch(
326311 _metadata = self ._task_metadata (task ),
327312 _stage_perf = task ._stage_perf ,
328313 )
329-
330- @staticmethod
331- def _metadata_rows_for_table (
332- table : pa .Table ,
333- metadata_by_sample : dict [str , str ],
334- ) -> list [dict [str , object ]]:
335- """Build metadata rows with single-pass sample type inference."""
336- if table .num_rows == 0 :
337- return []
338-
339- sample_stats : OrderedDict [str , tuple [int , bool , bool ]] = OrderedDict ()
340- for sample_id , modality in zip (table ["sample_id" ].to_pylist (), table ["modality" ].to_pylist (), strict = True ):
341- sid = str (sample_id )
342- modality_name = str (modality )
343- count , has_image , has_text = sample_stats .get (sid , (0 , False , False ))
344- sample_stats [sid ] = (
345- count + 1 ,
346- has_image or modality_name == "image" ,
347- has_text or modality_name == "text" ,
348- )
349-
350- return [
351- {
352- "sample_id" : sid ,
353- "sample_type" : BaseMultimodalReaderStage ._sample_type_from_summary (
354- num_rows = num_rows ,
355- has_image = has_image ,
356- has_text = has_text ,
357- ),
358- "metadata_json" : metadata_by_sample .get (sid ),
359- }
360- for sid , (num_rows , has_image , has_text ) in sample_stats .items ()
361- ]
362-
363- @staticmethod
364- def _sample_type_from_summary (num_rows : int , has_image : bool , has_text : bool ) -> str :
365- """Infer sample type from in-sample modality ordering."""
366- if num_rows == 1 :
367- return "single"
368- if num_rows == _PAIR_ELEMENT_COUNT and has_image and has_text :
369- return "pair"
370- return "interleaved"
0 commit comments