Adding modelopt_run_config.yaml and a main function for megatron data preprocessing #341
Changes from 1 commit
First changed file (the Megatron-Core distributed checkpointing plugin; the filename is not shown in this view):

```diff
@@ -22,6 +22,7 @@
 from typing import Any

 import torch
+import yaml
 from megatron.core import dist_checkpointing, mpu
 from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
 from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME
```
```diff
@@ -122,6 +123,30 @@ def save_sharded_modelopt_state(
         sharded_strategy: configures sharded tensors saving behavior and backend
         prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
     """
+
+    def _parse_transformer_config(transformer_config: dict) -> dict:
+        config = {}
+        for k, v in transformer_config.items():
+            if isinstance(v, (bool, int, str)):
+                config[k] = v
+            else:
+                config[k] = str(v)
+        config = {k: v for k, v in config.items() if "fp4" not in k and "fp8" not in k}
+        config = {k: v for k, v in config.items() if "tp_" not in k and "parallel" not in k}
+        config = {k: v for k, v in config.items() if "cuda_graph" not in k}
+        config = {k: v for k, v in config.items() if "init_" not in k and "cpu" not in k}
+        config = {k: v for k, v in config.items() if "recompute" not in k and "inference" not in k}
+        config = {k: v for k, v in config.items() if "pipeline" not in k and "comm" not in k}
+        config = {k: v for k, v in config.items() if "batch" not in k}
+        return config
+
+    if dist.is_master():
+        run_config_name = f"{checkpoint_name}/modelopt_run_config.yaml"
+        config_dict = _parse_transformer_config(copy.deepcopy(model[0].config.__dict__))
```
Reviewer comment (on the `copy.deepcopy` line):

@ChenhanYu I'm seeing an issue with the deepcopy. These fields are set on the config during training; they are bound methods of DDP class instances, which hold initialized PyTorch process groups as member variables. The deepcopy here ends up deep-copying those classes themselves, which fails. In Megatron Bridge we register a YAML representer for these objects, so we don't hit this when serializing our overall job config (which includes the transformer config) to YAML.

ChenhanYu: I see. The deepcopy here is likely not necessary. Let me remove it.
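As an illustrative aside (not the Megatron Bridge implementation): one way to avoid the deepcopy entirely is a PyYAML fallback representer that stringifies anything without a native YAML representation, such as bound methods or process-group handles. All names below are hypothetical.

```python
import yaml


class _ConfigDumper(yaml.SafeDumper):
    """Hypothetical dumper with a fallback for non-YAML-native objects."""


def _represent_fallback(dumper, data):
    # Bound methods, process-group handles, etc. have no native YAML
    # representation; emit their string form instead of raising.
    return dumper.represent_str(str(data))


# PyYAML falls back to the representer registered under None for types
# it does not otherwise know how to serialize.
_ConfigDumper.add_representer(None, _represent_fallback)

# Sketch of usage: dump the live config dict without copying it first;
# yaml.dump only reads the object graph, so no deepcopy is needed.
# yaml.dump(config_dict, stream, Dumper=_ConfigDumper, default_flow_style=False)
```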
```diff
+        config_dict["nvidia_modelopt_version"] = modelopt.__version__
+        with open(run_config_name, "w") as f:
+            yaml.dump(config_dict, f, default_flow_style=False)
+
     if not mto.ModeloptStateManager.is_converted(model[0]):
         return
     if len(model) > 1:
```
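The emitted modelopt_run_config.yaml is plain YAML, so it can be inspected or consumed directly; a minimal sketch (the checkpoint path and the num_layers field are illustrative, while nvidia_modelopt_version is written explicitly above):

```python
import yaml

# Illustrative checkpoint path; the file is written inside checkpoint_name.
with open("my_checkpoint/modelopt_run_config.yaml") as f:
    run_config = yaml.safe_load(f)

# Recorded explicitly by save_sharded_modelopt_state.
print(run_config["nvidia_modelopt_version"])

# The remaining keys are the filtered TransformerConfig fields.
print(run_config.get("num_layers"))
```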
Second changed file, modelopt/torch/utils/plugins/megatron_preprocess_data.py:

````diff
@@ -31,11 +31,14 @@
 ```
 """

 import argparse
 import json
 import multiprocessing
 import sys
 from pathlib import Path

+import requests
+from datasets import load_dataset
````
+40
to
+41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainEnsure packaging lists new dependencies. requests and datasets are new runtime deps for the CLI. Verify they’re declared (pyproject/requirements/extras). Also applies to: 206-289 🏁 Script executed: #!/bin/bash
# Verify packaging declares 'datasets' and 'requests'
rg -nP -C2 '^(datasets|requests)\b' -g 'pyproject.toml' -g 'requirements*.txt' -g 'setup.cfg' -g 'setup.py' Length of output: 1129 Add 'requests' and 'datasets' to project packaging/runtime deps modelopt/torch/utils/plugins/megatron_preprocess_data.py imports requests and datasets (lines 40–41; also referenced ~206–289). Search shows only 'datasets' in examples/*/requirements.txt and no top-level declaration for either. Add both to pyproject.toml (or setup.cfg / setup.py) or the CLI extras so the CLI runtime installs them. 🤖 Prompt for AI Agents
|
```diff
 from megatron.core.datasets import indexed_dataset
 from transformers import AutoTokenizer
```
```diff
@@ -198,3 +201,86 @@ def megatron_preprocess_data(
         final_enc_len += num_tokens

     print(f">>> Total number of tokens: {final_enc_len}")
+
+
+def main():
+    """Sample main function to process large data for pretraining.
+
+    Example usage:
+
+    >>> python megatron_preprocess_data.py \
+        --dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \
+        --tokenizer "nvidia/Nemotron-Pretraining-Tokenizer" \
+        --output_dir "./processed_data"
+    """
+    parser = argparse.ArgumentParser(prog="megatron_preprocess_data")
+    parser.add_argument("--input_path", type=str, default=None, help="Input path.")
+    parser.add_argument(
+        "--dataset", type=str, default=None, help="Hugging Face Hub dataset name or path"
+    )
+    parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset")
+    parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split")
```
coderabbitai[bot] commented on lines +225 to +226:

Fix the default `--split`: make it `None` to process all splits when omitted. The current default `"train"` contradicts the PR description ("all subsets and splits if subset/split omitted") and filters to train only. Set the default to `None` and keep the existing filter logic:

```diff
-    parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split")
+    parser.add_argument(
+        "--split",
+        type=str,
+        default=None,
+        help="Hugging Face Hub dataset split (process all if omitted)",
+    )
```
||||||||||||||||||||
"--output_dir", type=str, default="./processed_data", help="Output directory" | ||||||||||||||||||||
) | ||||||||||||||||||||
parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path") | ||||||||||||||||||||
parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize") | ||||||||||||||||||||
parser.add_argument("--append_eod", type=bool, default=False, help="Append <eod> token") | ||||||||||||||||||||
parser.add_argument( | ||||||||||||||||||||
"--max_sequence_length", type=int, default=None, help="Maximum sequence length" | ||||||||||||||||||||
) | ||||||||||||||||||||
parser.add_argument("--workers", type=int, default=8, help="Number of worker processes") | ||||||||||||||||||||
parser.add_argument("--log_interval", type=int, default=1000, help="Log interval") | ||||||||||||||||||||
args = parser.parse_args() | ||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
```diff
+
+    if args.input_path is None:
+        args.input_path = []
+
+    if args.dataset is None:
+        args.dataset = "nvidia/Nemotron-Pretraining-Dataset-sample"
+
+    response = requests.get(
+        "https://datasets-server.huggingface.co/splits?dataset={}".format(args.dataset),
+        timeout=10,
+    )
+
+    for entry in response.json()["splits"]:
+        skip_processing = False
+        name = entry["dataset"]
+        subset = entry.get("config", None)
+        split = entry["split"]
+
+        if args.subset is not None and args.subset != subset:
+            continue
+
+        if args.split is not None and args.split != split:
+            continue
+
+        print(f"Loading dataset {name} with subset {subset} and split {split}")
+        dataset = load_dataset(name, subset, split=split)
+
+        for key in args.json_keys:
+            if key not in dataset.features:
+                print(f"Key {key} not found in dataset features. Skipping...")
+                skip_processing = True
+                break
+
+        if skip_processing:
+            continue
+
+        json_file_path = args.output_dir + "/" + name + "_" + subset + "_" + split + ".jsonl"
+        dataset.to_json(json_file_path)
+        args.input_path += [json_file_path]
+
+    megatron_preprocess_data(
+        input_path=args.input_path,
+        output_dir=args.output_dir,
+        tokenizer_name_or_path=args.tokenizer,
+        json_keys=args.json_keys,
+        append_eod=args.append_eod,
+        max_sequence_length=args.max_sequence_length,
+        workers=args.workers,
+        log_interval=args.log_interval,
+    )
+
+
+if __name__ == "__main__":
+    main()
```
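For reference, the datasets-server /splits endpoint queried in main() returns a JSON object whose "splits" list carries the dataset name, optional config (subset), and split per entry; a standalone probe of the same endpoint:

```python
import requests

# Query the same Hugging Face datasets-server endpoint used by main().
resp = requests.get(
    "https://datasets-server.huggingface.co/splits"
    "?dataset=nvidia/Nemotron-Pretraining-Dataset-sample",
    timeout=10,
)
resp.raise_for_status()

for entry in resp.json()["splits"]:
    # Each entry describes one (dataset, config, split) combination.
    print(entry["dataset"], entry.get("config"), entry["split"])
```

And since main() is only a thin wrapper, the preprocessing entry point can also be called as a library function; a sketch assuming a pre-existing local JSONL file (the input file name is hypothetical; the import path follows the module location named in the review comment above):

```python
from modelopt.torch.utils.plugins.megatron_preprocess_data import (
    megatron_preprocess_data,
)

# Tokenize a local JSONL corpus with a "text" field into Megatron's
# indexed-dataset format, mirroring the CLI defaults above.
megatron_preprocess_data(
    input_path=["./my_corpus.jsonl"],  # hypothetical local input
    output_dir="./processed_data",
    tokenizer_name_or_path="nvidia/Nemotron-Pretraining-Tokenizer",
    json_keys=["text"],
    append_eod=False,
    max_sequence_length=None,
    workers=8,
    log_interval=1000,
)
```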