2121from nemo_curator .tasks import MultimodalBatch , _EmptyTask
2222from nemo_curator .tasks .multimodal import METADATA_SCHEMA , MULTIMODAL_SCHEMA
2323from nemo_curator .utils .file_utils import resolve_fs_and_path
24+ from nemo_curator .utils .multimodal_utils import cast_required_fields
2425
2526_DEFAULT_PARQUET_EXTENSIONS : Final [list [str ]] = [".parquet" ]
2627
@@ -43,8 +44,8 @@ class ParquetMultimodalReaderStage(BaseMultimodalReaderStage):
4344
4445 def __post_init__ (self ) -> None :
4546 super ().__post_init__ ()
46- self .columns = self ._validate_columns (self .columns )
47- self .metadata_columns = self ._validate_metadata_columns (self .metadata_columns )
47+ self .columns = self ._validate_column_selection (self .columns , field_name = "data.columns" )
48+ self .metadata_columns = self ._validate_column_selection (self .metadata_columns , field_name = "metadata.columns" )
4849 if self .columns is not None :
4950 missing_required = [name for name in MULTIMODAL_SCHEMA .names if name not in self .columns ]
5051 if missing_required :
@@ -55,37 +56,18 @@ def __post_init__(self) -> None:
5556 raise ValueError (msg )
5657
5758 @staticmethod
58- def _validate_columns (columns : list [str ] | None ) -> list [str ] | None :
59- """Validate optional data column selection."""
59+ def _validate_column_selection (columns : list [str ] | None , * , field_name : str ) -> list [str ] | None :
60+ """Validate optional column selection for data and metadata inputs ."""
6061 if columns is None :
6162 return None
6263 if len (columns ) == 0 :
63- msg = "columns must be a non-empty list when provided"
64+ msg = f" { field_name } must be a non-empty list when provided"
6465 raise ValueError (msg )
6566 seen : set [str ] = set ()
6667 normalized : list [str ] = []
6768 for column in columns :
6869 if not isinstance (column , str ) or not column :
69- msg = "columns entries must be non-empty strings"
70- raise ValueError (msg )
71- if column not in seen :
72- seen .add (column )
73- normalized .append (column )
74- return normalized
75-
76- @staticmethod
77- def _validate_metadata_columns (columns : list [str ] | None ) -> list [str ] | None :
78- """Validate optional metadata sidecar column selection."""
79- if columns is None :
80- return None
81- if len (columns ) == 0 :
82- msg = "metadata_columns must be a non-empty list when provided"
83- raise ValueError (msg )
84- seen : set [str ] = set ()
85- normalized : list [str ] = []
86- for column in columns :
87- if not isinstance (column , str ) or not column :
88- msg = "metadata_columns entries must be non-empty strings"
70+ msg = f"{ field_name } entries must be non-empty strings"
8971 raise ValueError (msg )
9072 if column not in seen :
9173 seen .add (column )
@@ -114,7 +96,7 @@ def _read_metadata_table(self, metadata_path: str) -> pa.Table:
11496 def _normalize_data_table (self , table : pa .Table ) -> pa .Table :
11597 source = self ._ensure_optional_string_column (table , "element_metadata_json" )
11698 self ._ensure_required_columns (source , MULTIMODAL_SCHEMA .names )
117- return self . _cast_required_fields (source , MULTIMODAL_SCHEMA )
99+ return cast_required_fields (source , MULTIMODAL_SCHEMA )
118100
119101 def _normalize_metadata_table (self , table : pa .Table ) -> pa .Table :
120102 self ._ensure_required_columns (
@@ -124,19 +106,7 @@ def _normalize_metadata_table(self, table: pa.Table) -> pa.Table:
124106 )
125107 source = self ._ensure_optional_string_column (table , "sample_type" )
126108 source = self ._ensure_optional_string_column (source , "metadata_json" )
127- return self ._cast_required_fields (source , METADATA_SCHEMA )
128-
129- @staticmethod
130- def _cast_required_fields (table : pa .Table , required_schema : pa .Schema ) -> pa .Table :
131- """Cast required fields in-place while preserving any extra columns."""
132- out = table
133- for required_field in required_schema :
134- col_idx = out .schema .get_field_index (required_field .name )
135- if col_idx >= 0 :
136- col = out [required_field .name ]
137- if not col .type .equals (required_field .type ):
138- out = out .set_column (col_idx , required_field .name , col .cast (required_field .type ))
139- return out
109+ return cast_required_fields (source , METADATA_SCHEMA )
140110
141111 @staticmethod
142112 def _ensure_optional_string_column (table : pa .Table , column_name : str ) -> pa .Table :
0 commit comments