diff --git a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
index c2f5bca7..f0901955 100644
--- a/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
+++ b/modelopt/torch/opt/plugins/mcore_dist_checkpointing.py
@@ -22,6 +22,7 @@
 from typing import Any
 
 import torch
+import yaml
 from megatron.core import dist_checkpointing, mpu
 from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
 from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME
@@ -35,6 +36,21 @@ SUPPORTED_WRAPPERS[Float16Module] = "module"
 
 
+DROP_SUBSTRINGS = [
+    "fp4",
+    "fp8",
+    "tp_",
+    "parallel",
+    "cuda_graph",
+    "init_",
+    "cpu",
+    "recompute",
+    "inference",
+    "pipeline",
+    "comm",
+    "batch",
+]
+
 
 def remove_per_module_state(
     modelopt_state: dict[str, Any],
@@ -122,6 +138,27 @@
         sharded_strategy: configures sharded tensors saving behavior and backend
         prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
     """
+
+    def _parse_transformer_config(transformer_config: dict) -> dict:
+        config = {}
+
+        for k, v in transformer_config.items():
+            if any(substring in k for substring in DROP_SUBSTRINGS):
+                continue
+            if isinstance(v, (bool, int, str)):
+                config[k] = v
+            else:
+                config[k] = str(v)
+
+        return config
+
+    if dist.is_master():
+        run_config_name = f"{checkpoint_name}/modelopt_run_config.yaml"
+        config_dict = _parse_transformer_config(copy.deepcopy(model[0].config.__dict__))
+        config_dict["nvidia_modelopt_version"] = modelopt.__version__
+        with open(run_config_name, "w") as f:
+            yaml.dump(config_dict, f, default_flow_style=False)
+
     if not mto.ModeloptStateManager.is_converted(model[0]):
         return
     if len(model) > 1:
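The patch above records the run configuration alongside the distributed checkpoint: on the master rank it dumps a filtered copy of `model[0].config.__dict__`, plus the ModelOpt version, to `<checkpoint_name>/modelopt_run_config.yaml`. A minimal standalone sketch of the filtering behavior (the toy input below is hypothetical, not taken from the patch):

```python
DROP_SUBSTRINGS = [
    "fp4", "fp8", "tp_", "parallel", "cuda_graph", "init_",
    "cpu", "recompute", "inference", "pipeline", "comm", "batch",
]

def _parse_transformer_config(transformer_config: dict) -> dict:
    config = {}
    for k, v in transformer_config.items():
        # Drop keys matching any DROP_SUBSTRINGS entry.
        if any(substring in k for substring in DROP_SUBSTRINGS):
            continue
        # Keep YAML-friendly primitives as-is; stringify everything else.
        config[k] = v if isinstance(v, (bool, int, str)) else str(v)
    return config

# Hypothetical toy input resembling a TransformerConfig.__dict__:
toy = {
    "num_layers": 2,
    "hidden_size": 64,
    "tensor_model_parallel_size": 4,  # dropped: contains "parallel"
    "fp8": "hybrid",                  # dropped: contains "fp8"
    "activation_func": print,         # kept, but stringified
}
print(_parse_transformer_config(toy))
# {'num_layers': 2, 'hidden_size': 64, 'activation_func': '<built-in function print>'}
```

Presumably the dropped substrings target launch-specific knobs (parallelism layout, precision recipe, batching, recompute) that describe a particular run rather than the checkpointed model itself.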
diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py
index ddc20aff..ac05e44f 100644
--- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py
+++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py
@@ -31,11 +31,14 @@
 ```
 """
 
+import argparse
 import json
 import multiprocessing
 import sys
 from pathlib import Path
 
+import requests
+from datasets import load_dataset
 from megatron.core.datasets import indexed_dataset
 from transformers import AutoTokenizer
 
@@ -198,3 +201,99 @@
         final_enc_len += num_tokens
 
     print(f">>> Total number of tokens: {final_enc_len}")
+
+
+def main():
+    """Sample main function to process large data for pretraining.
+
+    Example usage:
+
+        python megatron_preprocess_data.py \
+            --dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \
+            --tokenizer "meta-llama/Llama-3.2-1B-Instruct" \
+            --output_dir "./processed_data"
+    """
+    parser = argparse.ArgumentParser(prog="megatron_preprocess_data")
+    parser.add_argument(
+        "--input_path",
+        nargs="+",
+        default=None,
+        help="Input .jsonl file path(s). If omitted, --dataset is fetched from the Hugging Face Hub.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="nvidia/Nemotron-Pretraining-Dataset-sample",
+        help="Hugging Face Hub dataset name or path",
+    )
+    parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset")
+    parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split")
+    parser.add_argument(
+        "--output_dir", type=str, default="./processed_data", help="Output directory"
+    )
+    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path")
+    parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize")
+    parser.add_argument(
+        "--append_eod", action="store_true", help="Append an end-of-document token to each document"
+    )
+    parser.add_argument(
+        "--max_sequence_length", type=int, default=None, help="Maximum sequence length"
+    )
+    parser.add_argument("--workers", type=int, default=8, help="Number of worker processes")
+    parser.add_argument("--log_interval", type=int, default=1000, help="Log interval")
+    args = parser.parse_args()
+
+    if args.input_path is None:
+        # No pre-tokenized input given: discover the dataset's splits on the Hub
+        # and convert each matching split to a local .jsonl file first.
+        args.input_path = []
+
+        try:
+            response = requests.get(
+                f"https://datasets-server.huggingface.co/splits?dataset={args.dataset}",
+                timeout=10,
+            )
+            response.raise_for_status()
+        except requests.RequestException as e:
+            print(f"Failed to fetch dataset splits for {args.dataset}: {e}")
+            return
+
+        for entry in response.json()["splits"]:
+            name = entry["dataset"]
+            subset = entry.get("config", None)
+            split = entry["split"]
+
+            # Filter before downloading so non-matching splits are never loaded.
+            if args.subset is not None and args.subset != subset:
+                continue
+            if args.split is not None and args.split != split:
+                continue
+
+            print(f"Loading dataset {name} with subset {subset} and split {split}")
+            dataset = load_dataset(name, subset, split=split)
+
+            if any(key not in dataset.features for key in args.json_keys):
+                print(f"Not all keys {args.json_keys} found in dataset features. Skipping...")
+                continue
+
+            # Dataset names may contain "/"; flatten them so files land in output_dir.
+            file_name = f"{name.replace('/', '_')}_{subset}_{split}.jsonl"
+            json_file_path = Path(args.output_dir) / file_name
+            json_file_path.parent.mkdir(parents=True, exist_ok=True)
+            dataset.to_json(json_file_path)
+            args.input_path.append(str(json_file_path))
+
+    megatron_preprocess_data(
+        input_path=args.input_path,
+        output_dir=args.output_dir,
+        tokenizer_name_or_path=args.tokenizer,
+        json_keys=args.json_keys,
+        append_eod=args.append_eod,
+        max_sequence_length=args.max_sequence_length,
+        workers=args.workers,
+        log_interval=args.log_interval,
+    )
+
+
+if __name__ == "__main__":
+    main()
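The new `main()` is a thin CLI wrapper; the existing `megatron_preprocess_data()` entry point can also be driven directly from Python. A minimal sketch, assuming a pre-downloaded JSONL corpus with a `"text"` field (`./my_corpus.jsonl` is a hypothetical path):

```python
from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data

# Tokenize a local corpus into Megatron indexed-dataset (.bin/.idx) format,
# bypassing the Hugging Face Hub download path in main().
megatron_preprocess_data(
    input_path=["./my_corpus.jsonl"],  # hypothetical local file
    output_dir="./processed_data",
    tokenizer_name_or_path="meta-llama/Llama-3.2-1B-Instruct",
    json_keys=["text"],
    append_eod=True,
    max_sequence_length=4096,
    workers=8,
    log_interval=1000,
)
```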