
Commit fb875ac

Adding modelopt_run_config.yaml and a main function for megatron data preprocessing (#341)
Signed-off-by: Chenhan Yu <[email protected]>
1 parent be95a10 commit fb875ac

2 files changed, +129 -0 lines changed

modelopt/torch/opt/plugins/mcore_dist_checkpointing.py

Lines changed: 37 additions & 0 deletions
````diff
@@ -22,6 +22,7 @@
 from typing import Any
 
 import torch
+import yaml
 from megatron.core import dist_checkpointing, mpu
 from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
 from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME
@@ -35,6 +36,21 @@
 
 SUPPORTED_WRAPPERS[Float16Module] = "module"
 
+DROP_SUBSTRINGS = [
+    "fp4",
+    "fp8",
+    "tp_",
+    "parallel",
+    "cuda_graph",
+    "init_",
+    "cpu",
+    "recompute",
+    "inference",
+    "pipeline",
+    "comm",
+    "batch",
+]
+
 
 def remove_per_module_state(
     modelopt_state: dict[str, Any],
@@ -122,6 +138,27 @@ def save_sharded_modelopt_state(
         sharded_strategy: configures sharded tensors saving behavior and backend
         prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
     """
+
+    def _parse_transformer_config(transformer_config: dict) -> dict:
+        config = {}
+
+        for k, v in transformer_config.items():
+            if any(substring in k for substring in DROP_SUBSTRINGS):
+                continue
+            if isinstance(v, (bool, int, str)):
+                config[k] = v
+            else:
+                config[k] = str(v)
+
+        return config
+
+    if dist.is_master():
+        run_config_name = f"{checkpoint_name}/modelopt_run_config.yaml"
+        config_dict = _parse_transformer_config(copy.deepcopy(model[0].config.__dict__))
+        config_dict["nvidia_modelopt_version"] = modelopt.__version__
+        with open(run_config_name, "w") as f:
+            yaml.dump(config_dict, f, default_flow_style=False)
+
     if not mto.ModeloptStateManager.is_converted(model[0]):
         return
     if len(model) > 1:
````
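To make the new hook concrete, here is a minimal sketch (not part of the commit) of reading back the `modelopt_run_config.yaml` that `save_sharded_modelopt_state` now writes on the master rank; the checkpoint path and the `num_layers` key are illustrative assumptions:

```python
# Minimal sketch, assuming a checkpoint directory that already contains the
# run config written by save_sharded_modelopt_state(); the path is hypothetical.
import yaml

with open("/checkpoints/step-1000/modelopt_run_config.yaml") as f:
    run_config = yaml.safe_load(f)

# _parse_transformer_config keeps bool/int/str values as-is, stringifies
# everything else, and drops keys containing any DROP_SUBSTRINGS entry
# (e.g. "parallel", "fp8", "batch"), so the dump is a flat, readable record.
print(run_config["nvidia_modelopt_version"])  # always added before dumping
print(run_config.get("num_layers"))  # illustrative key; present only if the config defines it
```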

modelopt/torch/utils/plugins/megatron_preprocess_data.py

Lines changed: 92 additions & 0 deletions
````diff
@@ -31,11 +31,14 @@
     ```
 """
 
+import argparse
 import json
 import multiprocessing
 import sys
 from pathlib import Path
 
+import requests
+from datasets import load_dataset
 from megatron.core.datasets import indexed_dataset
 from transformers import AutoTokenizer
 
````
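The new `main()` in the next hunk discovers what to download by querying the Hugging Face datasets-server `/splits` endpoint. As a hedged sketch, the response shape it iterates over looks like this (the `config` value below is made up for illustration):

```python
# Illustrative /splits payload; main() reads entry["dataset"], entry.get("config"),
# and entry["split"] from each element of response.json()["splits"].
example_response = {
    "splits": [
        {
            "dataset": "nvidia/Nemotron-Pretraining-Dataset-sample",
            "config": "some_config",  # hypothetical subset name
            "split": "train",
        }
    ]
}

for entry in example_response["splits"]:
    print(entry["dataset"], entry.get("config"), entry["split"])
```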
````diff
@@ -198,3 +201,92 @@ def megatron_preprocess_data(
         final_enc_len += num_tokens
 
     print(f">>> Total number of tokens: {final_enc_len}")
+
+
+def main():
+    """Sample main function to process large data for pretraining.
+
+    Example usage:
+
+    >>> python megatron_preprocess_data.py \
+            --dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \
+            --tokenizer "meta-llama/Llama-3.2-1B-Instruct" \
+            --output_dir "./processed_data"
+    """
+    parser = argparse.ArgumentParser(prog="megatron_preprocess_data")
+    parser.add_argument("--input_path", type=str, default=None, help="Input path.")
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="nvidia/Nemotron-Pretraining-Dataset-sample",
+        help="Hugging Face Hub dataset name or path",
+    )
+    parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset")
+    parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split")
+    parser.add_argument(
+        "--output_dir", type=str, default="./processed_data", help="Output directory"
+    )
+    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path")
+    parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize")
+    parser.add_argument("--append_eod", action="store_true", help="Append <eod> token")
+    parser.add_argument(
+        "--max_sequence_length", type=int, default=None, help="Maximum sequence length"
+    )
+    parser.add_argument("--workers", type=int, default=8, help="Number of worker processes")
+    parser.add_argument("--log_interval", type=int, default=1000, help="Log interval")
+    args = parser.parse_args()
+
+    if args.input_path is None:
+        args.input_path = []
+
+        try:
+            response = requests.get(
+                f"https://datasets-server.huggingface.co/splits?dataset={args.dataset}",
+                timeout=10,
+            )
+            response.raise_for_status()
+        except requests.RequestException as e:
+            print(f"Failed to fetch dataset splits for {args.dataset}: {e}")
+            return
+
+        for entry in response.json()["splits"]:
+            skip_processing = False
+            name = entry["dataset"]
+            subset = entry.get("config", None)
+            split = entry["split"]
+
+            if args.subset is not None and args.subset != subset:
+                skip_processing = True
+            if args.split is not None and args.split != split:
+                skip_processing = True
+
+            print(f"Loading dataset {name} with subset {subset} and split {split}")
+            dataset = load_dataset(name, subset, split=split)
+
+            for key in args.json_keys:
+                if key not in dataset.features:
+                    print(f"Key {key} not found in dataset features. Skipping...")
+                    skip_processing = True
+                    break
+
+            if skip_processing:
+                continue
+
+            json_file_path = args.output_dir + "/" + name + "_" + subset + "_" + split + ".jsonl"
+            dataset.to_json(json_file_path)
+            args.input_path += [json_file_path]
+
+    megatron_preprocess_data(
+        input_path=args.input_path,
+        output_dir=args.output_dir,
+        tokenizer_name_or_path=args.tokenizer,
+        json_keys=args.json_keys,
+        append_eod=args.append_eod,
+        max_sequence_length=args.max_sequence_length,
+        workers=args.workers,
+        log_interval=args.log_interval,
+    )
+
+
+if __name__ == "__main__":
+    main()
````
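For completeness, a hedged sketch of skipping the CLI and split discovery entirely by calling the helper directly with a local JSONL file; the file path is hypothetical, and the keyword arguments mirror the call inside `main()`:

```python
# Minimal sketch, assuming a JSONL corpus already on disk; this bypasses the
# datasets-server lookup that main() performs when --input_path is omitted.
from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data

megatron_preprocess_data(
    input_path=["./processed_data/my_corpus.jsonl"],  # hypothetical local file
    output_dir="./processed_data",
    tokenizer_name_or_path="meta-llama/Llama-3.2-1B-Instruct",
    json_keys=["text"],
    append_eod=True,
    max_sequence_length=None,
    workers=8,
    log_interval=1000,
)
```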
