Skip to content

Commit 7831a78

Browse files
authored
Merge branch 'main' into chcui/gpt-oss-thd
2 parents 87c4dda + ea844b9 commit 7831a78

File tree

34 files changed

+532
-135
lines changed

34 files changed

+532
-135
lines changed

docs/training/packed-sequences.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ The {py:class}`bridge.data.datasets.packed_sequence.PackedSequenceSpecs` class p
6060
| `packed_train_data_path` | `str` | `None` | Custom path for packed training dataset file (`.npy` format). |
6161
| `packed_val_data_path` | `str` | `None` | Custom path for packed validation dataset file (`.npy` format). |
6262
| `packed_metadata_path` | `str` | `None` | Custom path for packing metadata file (`.jsonl` format). |
63+
| `pad_seq_to_mult` | `int \| None` | `1` | Pad each sample to a multiple of this value when generating packed datasets (e.g., set to `2 * context_parallel_size` for THD CP). |
6364
| `pad_cu_seqlens` | `bool` | `False` | Whether to pad `cu_seqlens` to constant size, required for CUDA graphs. |
6465

6566
### Batch Size Considerations

scripts/performance/run_recipe.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ def set_user_overrides(config, args):
6767
# Dataset configuration
6868
logging.info(f"Configuring dataset: type={args.data}")
6969

70+
cp_size = getattr(config.model, "context_parallel_size", 1) or 1
71+
pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1
72+
7073
# Create dataset configuration based on type
7174
if args.data == "mock":
7275
config.dataset = create_mock_dataset_config(seq_length=args.seq_length or 8192)
@@ -82,13 +85,19 @@ def set_user_overrides(config, args):
8285
if not args.dataset_root:
8386
raise ValueError("--dataset-root is required for squad dataset")
8487
config.dataset = create_squad_dataset_config(
85-
dataset_root=args.dataset_root, seq_length=args.seq_length or 8192, packed=False
88+
dataset_root=args.dataset_root,
89+
seq_length=args.seq_length or 8192,
90+
packed=False,
91+
pad_seq_to_mult=pad_seq_to_mult,
8692
)
8793
elif args.data == "squad_packed":
8894
if not args.dataset_root:
8995
raise ValueError("--dataset-root is required for squad_packed dataset")
9096
config.dataset = create_squad_dataset_config(
91-
dataset_root=args.dataset_root, seq_length=args.seq_length or 8192, packed=True
97+
dataset_root=args.dataset_root,
98+
seq_length=args.seq_length or 8192,
99+
packed=True,
100+
pad_seq_to_mult=pad_seq_to_mult,
92101
)
93102
else:
94103
raise ValueError(f"Unknown dataset type: {args.data}")

scripts/performance/utils/datasets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def create_rp2_dataset_config(dataset_paths, seq_length, index_mapping_dir=None)
5858
)
5959

6060

61-
def create_squad_dataset_config(dataset_root, seq_length, packed=False):
61+
def create_squad_dataset_config(dataset_root, seq_length, packed=False, pad_seq_to_mult=1):
6262
"""Create SQuAD dataset configuration for Megatron-Bridge using HF dataset."""
6363
from megatron.bridge.data.builders.hf_dataset import HFDatasetConfig
6464
from megatron.bridge.data.datasets.packed_sequence import PackedSequenceSpecs
@@ -67,7 +67,7 @@ def create_squad_dataset_config(dataset_root, seq_length, packed=False):
6767
# Create packed sequence specs if needed
6868
packed_sequence_specs = None
6969
if packed:
70-
packed_sequence_specs = PackedSequenceSpecs(packed_sequence_size=seq_length)
70+
packed_sequence_specs = PackedSequenceSpecs(packed_sequence_size=seq_length, pad_seq_to_mult=pad_seq_to_mult)
7171

7272
return HFDatasetConfig(
7373
dataset_name="squad", # Hugging Face dataset name

scripts/performance/utils/overrides.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,24 @@ def set_user_overrides(recipe: ConfigContainer, args: argparse.Namespace) -> Con
287287
elif args.data == "squad":
288288
if not args.dataset_root:
289289
raise ValueError("--dataset-root is required for squad dataset")
290+
cp_size = getattr(recipe.model, "context_parallel_size", 1) or 1
291+
pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1
290292
recipe.dataset = create_squad_dataset_config(
291-
dataset_root=args.dataset_root, seq_length=args.seq_length or recipe.model.seq_length, packed=False
293+
dataset_root=args.dataset_root,
294+
seq_length=args.seq_length or recipe.model.seq_length,
295+
packed=False,
296+
pad_seq_to_mult=pad_seq_to_mult,
292297
)
293298
elif args.data == "squad_packed":
294299
if not args.dataset_root:
295300
raise ValueError("--dataset-root is required for squad_packed dataset")
301+
cp_size = getattr(recipe.model, "context_parallel_size", 1) or 1
302+
pad_seq_to_mult = cp_size * 2 if cp_size > 1 else 1
296303
recipe.dataset = create_squad_dataset_config(
297-
dataset_root=args.dataset_root, seq_length=args.seq_length or recipe.model.seq_length, packed=True
304+
dataset_root=args.dataset_root,
305+
seq_length=args.seq_length or recipe.model.seq_length,
306+
packed=True,
307+
pad_seq_to_mult=pad_seq_to_mult,
298308
)
299309
if recipe.model.cuda_graph_impl != "none":
300310
recipe.dataset.packed_sequence_specs.pad_cu_seqlens = True

src/megatron/bridge/data/builders/finetuning_dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
self.packed_sequence_size = -1 if not packed_sequence_specs else packed_sequence_specs.packed_sequence_size
7878
self.dataset_kwargs = dataset_kwargs or {}
7979
self._pad_cu_seqlens = False if not packed_sequence_specs else packed_sequence_specs.pad_cu_seqlens
80+
self._pad_seq_to_mult = None if not packed_sequence_specs else packed_sequence_specs.pad_seq_to_mult
8081

8182
self.do_validation = do_validation
8283
self.do_test = do_test
@@ -106,6 +107,7 @@ def prepare_packed_data(self) -> None:
106107
seed=self.seed,
107108
output_metadata_path=self.pack_metadata,
108109
dataset_kwargs=self.dataset_kwargs,
110+
pad_seq_to_mult=self._pad_seq_to_mult,
109111
)
110112

111113
if self.do_validation and not self.validation_path_packed.is_file():
@@ -119,6 +121,7 @@ def prepare_packed_data(self) -> None:
119121
seed=self.seed,
120122
output_metadata_path=self.pack_metadata,
121123
dataset_kwargs=self.dataset_kwargs,
124+
pad_seq_to_mult=self._pad_seq_to_mult,
122125
)
123126

124127
def build(self) -> list[Optional[Any]]:
@@ -235,7 +238,9 @@ def default_pack_path(self) -> Path:
235238
The Path object for the default packing directory.
236239
"""
237240
tokenizer_model_name = self._extract_tokenizer_model_name()
238-
default_pack_path = self.dataset_root / "packed" / tokenizer_model_name
241+
default_pack_path = (
242+
self.dataset_root / "packed" / f"{tokenizer_model_name}_pad_seq_to_mult{self._pad_seq_to_mult}"
243+
)
239244
if not default_pack_path.exists():
240245
default_pack_path.mkdir(parents=True, exist_ok=True)
241246
logger.info(f"Using default path for packing files: {str(default_pack_path)}")

src/megatron/bridge/data/datasets/packed_sequence.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def tokenize_dataset(
3333
max_seq_length: int,
3434
seed: int,
3535
dataset_kwargs: dict | None = None,
36+
pad_seq_to_mult: int | None = 1,
3637
):
3738
"""
3839
Tokenizes a dataset from the provided path using the specified tokenizer
@@ -45,6 +46,8 @@ def tokenize_dataset(
4546
seed (int): Random seed for shuffling the dataset.
4647
dataset_kwargs (dict | None): Additional keyword arguments to pass to create_sft_dataset.
4748
Can include 'chat', 'use_hf_tokenizer_chat_template', 'tool_schemas', etc.
49+
pad_seq_to_mult (int | None): Optional multiple to pad each sequence to during packing
50+
preparation (e.g., set to 2 * context_parallel_size for THD CP).
4851
4952
Returns:
5053
np.ndarray: A NumPy array containing the tokenized data.
@@ -66,15 +69,56 @@ def tokenize_dataset(
6669
if hasattr(tokenizer, "_tokenizer"):
6770
tokenizer._tokenizer.chat_template = chat_template
6871

72+
if pad_seq_to_mult is not None and pad_seq_to_mult <= 0:
73+
raise ValueError("pad_seq_to_mult must be a positive integer when provided.")
74+
75+
# Treat None as "no padding" (multiple of 1); otherwise honor the requested multiple.
76+
pad_seq_length_to_mult = 1 if pad_seq_to_mult is None else max(1, pad_seq_to_mult)
77+
6978
dataset = create_sft_dataset(
7079
path=path,
7180
tokenizer=tokenizer,
7281
seq_length=max_seq_length,
7382
seed=seed,
7483
is_test=True,
84+
pad_seq_length_to_mult=pad_seq_length_to_mult,
7585
**dataset_kwargs,
7686
)
77-
return np.array([dataset[i] for i in range(len(dataset))])
87+
88+
pad_id = dataset.tokenizer.eod
89+
pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
90+
max_seq_length = dataset.max_seq_length
91+
dataset = np.array([dataset[i] for i in range(len(dataset))])
92+
93+
if pad_seq_to_mult > 1:
94+
95+
def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id):
96+
"""
97+
Pad each individual data point to the length of max_length_to_pad.
98+
This keeps packed samples divisible by the requested multiple (used for CP/THD).
99+
"""
100+
assert max_seq_length >= max_length_to_pad
101+
for key, val in data.items():
102+
if key in {"input_ids", "context_ids"}:
103+
if len(val) <= max_length_to_pad:
104+
# input_ids are truncated by 1 for labels; add 1 extra pad token
105+
val = val + [pad_id] * (max_length_to_pad - len(val) + 1)
106+
elif len(val) > max_seq_length:
107+
logging.info(
108+
"Sequence length %d is larger than max_seq_length %d; truncating for packing.",
109+
len(val),
110+
max_seq_length,
111+
)
112+
val = val[:max_seq_length]
113+
data[key] = val
114+
return
115+
116+
ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
117+
for data in dataset:
118+
max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data["input_ids"]), pad_seq_length_to_mult))
119+
pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id)
120+
121+
return dataset
78122

79123

80124
def prepare_packed_sequence_data(
@@ -87,6 +131,7 @@ def prepare_packed_sequence_data(
87131
seed: int | None = 0,
88132
packing_algorithm: str = "first_fit_shuffle",
89133
dataset_kwargs: dict | None = None,
134+
pad_seq_to_mult: int | None = 1,
90135
):
91136
"""
92137
Prepares a packed sequence dataset from a given input file and saves it to an output file.
@@ -103,12 +148,21 @@ def prepare_packed_sequence_data(
103148
currently supports "first_fit_shuffle" and "first_fit_decreasing".
104149
dataset_kwargs (dict | None): Additional keyword arguments to pass to create_sft_dataset.
105150
Enables packing with chat templates, tool schemas, etc.
151+
pad_seq_to_mult (int | None): Optional multiple to pad each sequence to during packing
152+
preparation (e.g., set to 2 * context_parallel_size for THD CP).
106153
107154
Returns:
108155
None: Saves the packed sequence data to the specified output path.
109156
"""
110157
logger.info(f"Preparing packed sequence from {input_path}")
111-
dataset = tokenize_dataset(input_path, tokenizer, max_seq_length, seed, dataset_kwargs)
158+
dataset = tokenize_dataset(
159+
input_path,
160+
tokenizer,
161+
max_seq_length,
162+
seed,
163+
dataset_kwargs,
164+
pad_seq_to_mult=pad_seq_to_mult,
165+
)
112166
sequences, histogram = create_hist(dataset, max_seq_length)
113167

114168
assignments, packing_metadata = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
@@ -185,6 +239,11 @@ class PackedSequenceSpecs:
185239
"""
186240
If True, pad cu_seqlens to a constant size, which is required for use with cudagraphs.
187241
"""
242+
pad_seq_to_mult: int | None = 1
243+
"""
244+
Optional multiple to pad each sample to when generating packed datasets.
245+
For THD/context parallel, set to (context_parallel_size * 2) to keep samples divisible.
246+
"""
188247

189248
def __post_init__(self):
190249
if self.packed_train_data_path is not None:
@@ -212,3 +271,6 @@ def __post_init__(self):
212271
assert self.packed_val_data_path.exists(), (
213272
f"packed validation data file does not exist: {self.packed_val_data_path}"
214273
)
274+
275+
if self.pad_seq_to_mult is not None and self.pad_seq_to_mult <= 0:
276+
raise ValueError("pad_seq_to_mult must be a positive integer when provided.")

src/megatron/bridge/data/datasets/sft.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ def __init__(
225225
output_original_text: bool = False,
226226
ceil_to_power_2: bool = False,
227227
get_attention_mask_from_fusion: bool = True,
228-
sanity_check_dist_workers: bool = True,
229228
):
230229
"""
231230
file_path: Path to a JSONL GPT supervised fine-tuning dataset.
@@ -274,7 +273,6 @@ def __init__(
274273
output_original_text (bool): if true, will keep the original text in the output alongside the tokenized ids.
275274
get_attention_mask_from_fusion (bool): if true, lets attention kernel handle creation of causal mask instead
276275
of adding it to the batch dict.
277-
sanity_check_dist_workers (bool): if true, will run sanity check across workers when making mapping.
278276
"""
279277
self.tokenizer = tokenizer
280278
self.file_path = file_path
@@ -303,7 +301,6 @@ def __init__(
303301
self.output_original_text = output_original_text
304302
self.ceil_to_power_2 = ceil_to_power_2
305303
self.get_attention_mask_from_fusion = get_attention_mask_from_fusion
306-
self.sanity_check_dist_workers = sanity_check_dist_workers
307304

308305
if special_tokens is None:
309306
self.special_tokens = {
@@ -385,7 +382,6 @@ def _build_samples_mapping(self):
385382
binary_head=False,
386383
index_mapping_dir=self.index_mapping_dir,
387384
samples_mapping=osm,
388-
sanity_check_dist_workers=self.sanity_check_dist_workers,
389385
)
390386
else:
391387
self.samples_mapping = None
@@ -914,13 +910,10 @@ def collate_fn(self, batch):
914910
for i in range(len(item["seq_boundaries"]) - 1):
915911
current_seq = item["input_ids"][item["seq_boundaries"][i] : item["seq_boundaries"][i + 1] - 1]
916912

917-
# since the data could be prepadded with tokenizer's eos_id,
918-
# we can find out the index of all the eos_id
919-
eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id)
920-
921-
# The second eos_id index marks the length of the original unpadded sequence if the sequence is
922-
# prepadded for cp_size > 1. Otherwise, there is no extra padding.
923-
seqlen_unpadded = eos_idx[0][1] + 1 if eos_idx[0].shape[0] > 1 else len(current_seq)
913+
# Stop unpadded lengths at the last non-eos token so padding eos are excluded.
914+
current_seq_arr = np.array(current_seq)
915+
non_eos_positions = np.where(current_seq_arr != self.tokenizer.eos_id)[0]
916+
seqlen_unpadded = non_eos_positions[-1] + 1 if non_eos_positions.size > 0 else 0
924917
cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded)
925918

926919
# if extra paddings are added in the packed sequence, they can't be counted as
@@ -944,10 +937,15 @@ def collate_fn(self, batch):
944937
loss_mask = self._collate_item(loss_mask, max_length=max_length, pad_id=0)
945938
position_ids = self._collate_item(position_ids, max_length=max_length, pad_id=0)
946939

940+
tokens = torch.LongTensor(input_ids)
941+
loss_mask = torch.LongTensor(loss_mask)
942+
# drop any padding/eos tokens from contributing to the loss
943+
loss_mask[tokens == self.tokenizer.eos_id] = 0
944+
947945
processed_batch = {
948-
"tokens": torch.LongTensor(input_ids),
946+
"tokens": tokens,
949947
"labels": torch.LongTensor(labels),
950-
"loss_mask": torch.LongTensor(loss_mask),
948+
"loss_mask": loss_mask,
951949
"position_ids": torch.LongTensor(position_ids),
952950
"token_count": token_count,
953951
}

src/megatron/bridge/data/datasets/utils.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -731,11 +731,10 @@ def _get_samples_mapping(
731731
binary_head,
732732
index_mapping_dir: str = None,
733733
samples_mapping: Any = None,
734-
sanity_check_dist_workers: bool = True,
735734
):
736735
"""Get a list that maps a sample index to a starting sentence index, end sentence index, and length"""
737-
738-
from megatron.core import parallel_state
736+
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
737+
rank = torch.distributed.get_rank() if is_distributed else 0
739738

740739
if not num_epochs:
741740
if not max_num_samples:
@@ -760,7 +759,7 @@ def _get_samples_mapping(
760759
indexmap_filename += ".npy"
761760

762761
# Build the indexed mapping if not exist and not provided externally.
763-
if samples_mapping is None and torch.distributed.get_rank() == 0 and not os.path.isfile(indexmap_filename):
762+
if samples_mapping is None and rank == 0 and not os.path.isfile(indexmap_filename):
764763
# Fake index mapping if missing
765764
if (getattr(indexed_dataset, "doc_idx", None) is None) and (getattr(indexed_dataset, "sizes", None) is None):
766765
_make_indexed_dataset_compatibility(indexed_dataset)
@@ -776,7 +775,7 @@ def _get_samples_mapping(
776775
assert indexed_dataset.sizes.dtype == np.int32
777776

778777
# Build samples mapping
779-
verbose = torch.distributed.get_rank() == 0
778+
verbose = rank == 0
780779
start_time = time.time()
781780
logger.info(" > building samples index mapping for {} ...".format(name))
782781
# First compile and then import.
@@ -806,15 +805,11 @@ def _get_samples_mapping(
806805
" > elasped time to build and save samples mapping (seconds): {:4f}".format(time.time() - start_time)
807806
)
808807

809-
if sanity_check_dist_workers:
808+
# Ensure the mapping exists before all ranks attempt to load it.
809+
# Skip barrier when invoked from a rank-0-only data preparation flow (see `rank_0_prepare_data()`).
810+
if is_distributed and not rank_0_prepare_data():
810811
torch.distributed.barrier()
811-
counts = torch.cuda.LongTensor([1])
812-
torch.distributed.all_reduce(counts, group=parallel_state.get_data_parallel_group(with_context_parallel=True))
813-
torch.distributed.all_reduce(counts, group=parallel_state.get_pipeline_model_parallel_group())
814-
assert counts[0].item() == (
815-
torch.distributed.get_world_size()
816-
// torch.distributed.get_world_size(group=parallel_state.get_tensor_model_parallel_group())
817-
)
812+
818813
# Load indexed dataset if not given externally.
819814
if samples_mapping is None:
820815
logger.info(" > loading indexed mapping from {}".format(indexmap_filename))

0 commit comments

Comments
 (0)