3 changes: 3 additions & 0 deletions pyproject.toml
@@ -127,6 +127,9 @@ nvidia-modelopt = { git = "https://github.com/NVIDIA/TensorRT-Model-Optimizer.gi
recipes = [
"nemo-run>=0.5.0a0,<0.6.0",
]
parquet = [
"pyarrow>=14.0.0",
]
tensor-inspect = [
"nvdlfw-inspect==0.2.1",
]
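The new `parquet` extra keeps pyarrow optional. A minimal sketch of how a consuming module would typically guard the optional import (the guard and message below are illustrative, not code from this PR):

```python
# Hypothetical import guard; only the pyarrow>=14.0.0 bound comes from the extra above.
try:
    import pyarrow  # provided by the optional `parquet` extra
except ImportError as err:
    raise ImportError(
        "Packed parquet support requires pyarrow>=14.0.0; "
        "install it via the `parquet` extra."
    ) from err
```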
157 changes: 120 additions & 37 deletions src/megatron/bridge/data/builders/finetuning_dataset.py
@@ -20,6 +20,10 @@
from megatron.core.msc_utils import MultiStorageClientFeature
from megatron.core.tokenizers.text.libraries import HuggingFaceTokenizer

from megatron.bridge.data.datasets.packed_parquet import (
is_packed_parquet_spec,
resolve_packed_parquet_paths,
)
from megatron.bridge.data.datasets.packed_sequence import PackedSequenceSpecs
from megatron.bridge.data.datasets.sft import create_sft_dataset
from megatron.bridge.training.tokenizers.tokenizer import _HuggingFaceTokenizer
@@ -92,37 +96,89 @@ def prepare_data(self) -> None:
self.prepare_packed_data()

def prepare_packed_data(self) -> None:
"""Prepare packed sequence data files if configured."""
if self.packed_sequence_size > 0:
from megatron.bridge.data.datasets.packed_sequence import prepare_packed_sequence_data

if not self.train_path_packed.is_file():
print_rank_0(f"Preparing packed training data at {self.train_path_packed}")
prepare_packed_sequence_data(
input_path=self.train_path,
output_path=self.train_path_packed,
packed_sequence_size=self.packed_sequence_size,
tokenizer=self.tokenizer,
max_seq_length=self.seq_length,
seed=self.seed,
output_metadata_path=self.pack_metadata,
dataset_kwargs=self.dataset_kwargs,
pad_seq_to_mult=self._pad_seq_to_mult,
)

if self.do_validation and not self.validation_path_packed.is_file():
print_rank_0(f"Preparing packed validation data at {self.validation_path_packed}")
prepare_packed_sequence_data(
input_path=self.validation_path,
output_path=self.validation_path_packed,
packed_sequence_size=self.packed_sequence_size,
tokenizer=self.tokenizer,
max_seq_length=self.seq_length,
seed=self.seed,
output_metadata_path=self.pack_metadata,
dataset_kwargs=self.dataset_kwargs,
pad_seq_to_mult=self._pad_seq_to_mult,
)
"""Prepare packed sequence data files if configured.

Skips preparation if:
- packed_sequence_size <= 0 (packing disabled)
- train/val paths are already packed parquet specs (externally prepared)
- train/val .npy files already exist
"""
if self.packed_sequence_size <= 0:
return

from megatron.bridge.data.datasets.packed_sequence import prepare_packed_sequence_data

# Skip if train path is already a packed parquet spec (externally prepared)
if is_packed_parquet_spec(str(self.train_path_packed)):
print_rank_0(
f"Skipping packed training data preparation - using externally prepared "
f"packed parquet: {self.train_path_packed}"
)
elif not self._packed_path_exists(self.train_path_packed):
print_rank_0(f"Preparing packed training data at {self.train_path_packed}")
prepare_packed_sequence_data(
input_path=self.train_path,
output_path=self.train_path_packed,
packed_sequence_size=self.packed_sequence_size,
tokenizer=self.tokenizer,
max_seq_length=self.seq_length,
seed=self.seed,
output_metadata_path=self.pack_metadata,
dataset_kwargs=self.dataset_kwargs,
pad_seq_to_mult=self._pad_seq_to_mult,
)

if not self.do_validation:
return
Comment on lines +142 to +150
⚠️ Potential issue | 🟠 Major

Fail fast now that automatic .npy preparation is no longer supported.

This branch only warns and returns, so the build can continue until _create_dataset() later hands back None for the missing packed path. Raising a clear error here would make the deprecation actionable.

🔧 Suggested change
         if packed_path_str.lower().endswith(".npy"):
-            warnings.warn(
-                "Automatic .npy packed sequence preparation is deprecated and will be removed in the next release. "
-                "Please use packed parquet format instead.",
-                DeprecationWarning,
-                stacklevel=3,
-            )
-            return
+            raise NotImplementedError(
+                "Automatic .npy packed sequence preparation is deprecated and no longer supported. "
+                "Please switch the packed output path to `.parquet` or `.pq`."
+            )
Based on learnings: when a feature is not supported, raise an explicit error with a clear message instead of silently ignoring the input, so callers fail fast.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/builders/finetuning_dataset.py` around lines 142-150, the branch where packed_path_str.lower().endswith(".npy") only emits a DeprecationWarning and returns, which lets the build proceed and later fail with a None packed path. Change this to fail fast by raising an explicit error (e.g., ValueError or RuntimeError) in that same branch, replacing the warnings.warn + return with a raised exception that includes the deprecation message and actionable guidance to use the packed parquet format, so callers immediately see the unsupported input.
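For illustration, a self-contained sketch of the warn-and-return pattern flagged above (the names are placeholders, not the PR's actual code):

```python
import warnings

# Sketch of the failure mode the comment describes: under default filters,
# DeprecationWarning is only displayed when it originates in __main__, so inside
# library code a warn-and-return branch can pass unnoticed and the build
# continues until the missing packed file surfaces later as None.
def prepare(packed_path_str: str) -> None:
    if packed_path_str.lower().endswith(".npy"):
        warnings.warn(
            "Automatic .npy packing is deprecated", DeprecationWarning, stacklevel=2
        )
        return  # execution continues instead of failing fast

prepare("train.npy")
```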


# Skip if val path is already a packed parquet spec (externally prepared)
if is_packed_parquet_spec(str(self.validation_path_packed)):
print_rank_0(
f"Skipping packed validation data preparation - using externally prepared "
f"packed parquet: {self.validation_path_packed}"
)
elif not self._packed_path_exists(self.validation_path_packed):
print_rank_0(f"Preparing packed validation data at {self.validation_path_packed}")
prepare_packed_sequence_data(
input_path=self.validation_path,
output_path=self.validation_path_packed,
packed_sequence_size=self.packed_sequence_size,
tokenizer=self.tokenizer,
max_seq_length=self.seq_length,
seed=self.seed,
output_metadata_path=self.pack_metadata,
dataset_kwargs=self.dataset_kwargs,
pad_seq_to_mult=self._pad_seq_to_mult,
)

def _packed_path_exists(self, path: Union[str, Path]) -> bool:
"""Check if a packed data path exists.

For .npy files: check file exists
For packed parquet specs: check if resolution returns non-empty

Args:
path: The path to check

Returns:
True if the packed data exists
"""
path_str = str(path)

# For packed parquet specs, check if resolution returns files
if is_packed_parquet_spec(path_str):
try:
resolved = resolve_packed_parquet_paths(path_str)
return len(resolved) > 0
except ValueError:
return False

# For .npy or other files, check existence
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
return msc.Path(path_str).is_file()
else:
return Path(path_str).is_file()

def build(self) -> list[Optional[Any]]:
"""Build train, validation, and test datasets.
Expand Down Expand Up @@ -191,33 +247,60 @@ def _create_dataset(
"""Create a single dataset instance (train, validation, or test).

Args:
path: Path to the dataset file
path: Path to the dataset file or packed parquet spec
pack_metadata_path: Path to the packed sequence metadata
is_test: Whether this is a test dataset
**kwargs: Additional arguments to pass to the dataset constructor

Returns:
The created dataset
"""
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
path_exists = msc.Path(path).exists()
path_str = str(path)

# Check if path exists - handle packed parquet specs differently
if is_packed_parquet_spec(path_str):
# For packed parquet specs, check via resolution
try:
resolved = resolve_packed_parquet_paths(path_str)
path_exists = len(resolved) > 0
except ValueError:
path_exists = False
else:
path_exists = Path(path).exists()
# Standard file/path existence check
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
path_exists = msc.Path(path_str).exists()
else:
path_exists = Path(path_str).exists()

if not path_exists:
print_rank_0(f"Warning: Dataset path {path} does not exist")
return None

is_not_packing = self.packed_sequence_size <= 0

# For packed parquet from external sources, only pass metadata if pad_cu_seqlens is True
# This avoids "missing metadata" errors when using externally prepared packed data
effective_metadata_path = None
if not is_not_packing:
if self._pad_cu_seqlens:
# pad_cu_seqlens requires metadata
effective_metadata_path = pack_metadata_path
elif is_packed_parquet_spec(path_str):
# Externally prepared packed parquet without pad_cu_seqlens doesn't need metadata
effective_metadata_path = None
else:
# .npy files prepared by MB include metadata
effective_metadata_path = pack_metadata_path

return create_sft_dataset(
path,
tokenizer=self.tokenizer,
seq_length=(self.seq_length if is_not_packing else self.packed_sequence_size),
memmap_workers=self.memmap_workers,
seed=self.seed,
is_test=is_test,
pack_metadata_file_path=None if is_not_packing else pack_metadata_path,
pack_metadata_file_path=effective_metadata_path,
pad_cu_seqlens=False if is_not_packing else self._pad_cu_seqlens,
**kwargs,
)
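Note: the `packed_parquet` helpers imported at the top of this file are not shown in this diff. A minimal sketch of the semantics the call sites appear to rely on — a boolean spec check plus a resolver that raises ValueError when nothing matches — assuming a spec is simply a parquet-suffixed path or glob (the `.parquet`/`.pq` suffixes come from the review comment; everything else here is an assumption):

```python
from glob import glob
from pathlib import Path

PARQUET_SUFFIXES = (".parquet", ".pq")

def is_packed_parquet_spec(path: str) -> bool:
    """Assumed semantics: any parquet-suffixed path (or glob) counts as a spec."""
    return path.lower().endswith(PARQUET_SUFFIXES)

def resolve_packed_parquet_paths(spec: str) -> list[Path]:
    """Assumed semantics: expand globs, keep existing files, raise ValueError if empty."""
    candidates = (
        [Path(p) for p in glob(spec)]
        if any(ch in spec for ch in "*?[")
        else [Path(spec)]
    )
    files = [p for p in candidates if p.is_file()]
    if not files:
        raise ValueError(f"no parquet files resolved from spec: {spec}")
    return files
```

This matches how the builder uses them: `is_packed_parquet_spec` gates the skip/metadata branches, and `resolve_packed_parquet_paths` is probed with `len(resolved) > 0` inside a `try/except ValueError` for the existence checks.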