Skip to content

Commit 39617df

Browse files
committed
Refactor multimodal column validation and shared schema casting
1 parent 793c4bb commit 39617df

File tree

3 files changed

+24
-54
lines changed

3 files changed

+24
-54
lines changed

nemo_curator/stages/multimodal/io/readers/parquet.py

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nemo_curator.tasks import MultimodalBatch, _EmptyTask
2222
from nemo_curator.tasks.multimodal import METADATA_SCHEMA, MULTIMODAL_SCHEMA
2323
from nemo_curator.utils.file_utils import resolve_fs_and_path
24+
from nemo_curator.utils.multimodal_utils import cast_required_fields
2425

2526
_DEFAULT_PARQUET_EXTENSIONS: Final[list[str]] = [".parquet"]
2627

@@ -43,8 +44,8 @@ class ParquetMultimodalReaderStage(BaseMultimodalReaderStage):
4344

4445
def __post_init__(self) -> None:
4546
super().__post_init__()
46-
self.columns = self._validate_columns(self.columns)
47-
self.metadata_columns = self._validate_metadata_columns(self.metadata_columns)
47+
self.columns = self._validate_column_selection(self.columns, field_name="data.columns")
48+
self.metadata_columns = self._validate_column_selection(self.metadata_columns, field_name="metadata.columns")
4849
if self.columns is not None:
4950
missing_required = [name for name in MULTIMODAL_SCHEMA.names if name not in self.columns]
5051
if missing_required:
@@ -55,37 +56,18 @@ def __post_init__(self) -> None:
5556
raise ValueError(msg)
5657

5758
@staticmethod
58-
def _validate_columns(columns: list[str] | None) -> list[str] | None:
59-
"""Validate optional data column selection."""
59+
def _validate_column_selection(columns: list[str] | None, *, field_name: str) -> list[str] | None:
60+
"""Validate optional column selection for data and metadata inputs."""
6061
if columns is None:
6162
return None
6263
if len(columns) == 0:
63-
msg = "columns must be a non-empty list when provided"
64+
msg = f"{field_name} must be a non-empty list when provided"
6465
raise ValueError(msg)
6566
seen: set[str] = set()
6667
normalized: list[str] = []
6768
for column in columns:
6869
if not isinstance(column, str) or not column:
69-
msg = "columns entries must be non-empty strings"
70-
raise ValueError(msg)
71-
if column not in seen:
72-
seen.add(column)
73-
normalized.append(column)
74-
return normalized
75-
76-
@staticmethod
77-
def _validate_metadata_columns(columns: list[str] | None) -> list[str] | None:
78-
"""Validate optional metadata sidecar column selection."""
79-
if columns is None:
80-
return None
81-
if len(columns) == 0:
82-
msg = "metadata_columns must be a non-empty list when provided"
83-
raise ValueError(msg)
84-
seen: set[str] = set()
85-
normalized: list[str] = []
86-
for column in columns:
87-
if not isinstance(column, str) or not column:
88-
msg = "metadata_columns entries must be non-empty strings"
70+
msg = f"{field_name} entries must be non-empty strings"
8971
raise ValueError(msg)
9072
if column not in seen:
9173
seen.add(column)
@@ -114,7 +96,7 @@ def _read_metadata_table(self, metadata_path: str) -> pa.Table:
11496
def _normalize_data_table(self, table: pa.Table) -> pa.Table:
11597
source = self._ensure_optional_string_column(table, "element_metadata_json")
11698
self._ensure_required_columns(source, MULTIMODAL_SCHEMA.names)
117-
return self._cast_required_fields(source, MULTIMODAL_SCHEMA)
99+
return cast_required_fields(source, MULTIMODAL_SCHEMA)
118100

119101
def _normalize_metadata_table(self, table: pa.Table) -> pa.Table:
120102
self._ensure_required_columns(
@@ -124,19 +106,7 @@ def _normalize_metadata_table(self, table: pa.Table) -> pa.Table:
124106
)
125107
source = self._ensure_optional_string_column(table, "sample_type")
126108
source = self._ensure_optional_string_column(source, "metadata_json")
127-
return self._cast_required_fields(source, METADATA_SCHEMA)
128-
129-
@staticmethod
130-
def _cast_required_fields(table: pa.Table, required_schema: pa.Schema) -> pa.Table:
131-
"""Cast required fields in-place while preserving any extra columns."""
132-
out = table
133-
for required_field in required_schema:
134-
col_idx = out.schema.get_field_index(required_field.name)
135-
if col_idx >= 0:
136-
col = out[required_field.name]
137-
if not col.type.equals(required_field.type):
138-
out = out.set_column(col_idx, required_field.name, col.cast(required_field.type))
139-
return out
109+
return cast_required_fields(source, METADATA_SCHEMA)
140110

141111
@staticmethod
142112
def _ensure_optional_string_column(table: pa.Table, column_name: str) -> pa.Table:

nemo_curator/stages/multimodal/io/writers/base.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
resolve_sidecar_output_path,
2626
resolve_task_scoped_output_path,
2727
)
28-
from nemo_curator.utils.multimodal_utils import sort_multimodal_table
28+
from nemo_curator.utils.multimodal_utils import cast_required_fields, sort_multimodal_table
2929

3030
_SUPPORTED_WRITE_MODES: set[str] = {"overwrite", "error", "ignore"}
3131

@@ -137,19 +137,7 @@ def _build_output_table(self, task: MultimodalBatch) -> pa.Table:
137137
if missing:
138138
msg = f"{self.__class__.__name__} requires columns: {missing}"
139139
raise ValueError(msg)
140-
return self._cast_required_fields(table, MULTIMODAL_SCHEMA)
141-
142-
@staticmethod
143-
def _cast_required_fields(table: pa.Table, required_schema: pa.Schema) -> pa.Table:
144-
"""Cast required fields in-place while preserving any extra columns."""
145-
out = table
146-
for required_field in required_schema:
147-
col_idx = out.schema.get_field_index(required_field.name)
148-
if col_idx >= 0:
149-
col = out[required_field.name]
150-
if not col.type.equals(required_field.type):
151-
out = out.set_column(col_idx, required_field.name, col.cast(required_field.type))
152-
return out
140+
return cast_required_fields(table, MULTIMODAL_SCHEMA)
153141

154142
def _write_tabular(self, table: pa.Table, output_path: str, format_name: Literal["parquet", "arrow"]) -> None:
155143
"""Write Arrow table to parquet or arrow artifact path."""
@@ -201,7 +189,7 @@ def _build_metadata_table(task: MultimodalBatch) -> pa.Table:
201189
column_name,
202190
pa.nulls(metadata_table.num_rows, type=pa.string()),
203191
)
204-
metadata_table = BaseMultimodalWriterStage._cast_required_fields(metadata_table, METADATA_SCHEMA)
192+
metadata_table = cast_required_fields(metadata_table, METADATA_SCHEMA)
205193
return metadata_table.sort_by([("sample_id", "ascending")])
206194

207195
@staticmethod

nemo_curator/utils/multimodal_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@
1717
import pyarrow as pa
1818

1919

20+
def cast_required_fields(table: pa.Table, required_schema: pa.Schema) -> pa.Table:
    """Return ``table`` with required columns cast to the types in ``required_schema``.

    Each field of ``required_schema`` is looked up by name in the table. A
    column that is found and whose type differs from the required type is
    replaced, at the same position, by a cast copy. Columns that already match,
    required fields that cannot be located by name, and columns not mentioned
    in ``required_schema`` are left untouched.

    Args:
        table: Table whose columns should be aligned with the schema types.
        required_schema: Schema giving the target type for each required column.

    Returns:
        The table produced by the successive ``set_column`` calls, or ``table``
        itself when no column needed casting.
    """
    result = table
    for field in required_schema:
        # A negative index means the name could not be resolved to one column.
        position = result.schema.get_field_index(field.name)
        if position < 0:
            continue
        current = result[field.name]
        if current.type.equals(field.type):
            continue
        # set_column hands back a table, so rebind instead of relying on mutation.
        result = result.set_column(position, field.name, current.cast(field.type))
    return result
30+
31+
2032
def validate_content_path_loading_mode(*, content_path: str, row_indices: list[int], content_keys: list[object | None]) -> None:
2133
"""Ensure one content_path group uses a single loading mode."""
2234
for idx in row_indices:

0 commit comments

Comments
 (0)