Commit e1f1bbb

new: adding modelopt_run_config.yaml and a main function for data preprocessing

Signed-off-by: Chenhan Yu <[email protected]>

1 parent bbb2304

File tree: 2 files changed, 111 additions & 0 deletions

modelopt/torch/opt/plugins/mcore_dist_checkpointing.py

Lines changed: 25 additions & 0 deletions
@@ -22,6 +22,7 @@
 from typing import Any
 
 import torch
+import yaml
 from megatron.core import dist_checkpointing, mpu
 from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
 from megatron.core.dist_checkpointing.strategies.common import COMMON_STATE_FNAME
@@ -122,6 +123,30 @@ def save_sharded_modelopt_state(
         sharded_strategy: configures sharded tensors saving behavior and backend
         prefix: the prefix to add to the modelopt_state keys ("model." for NeMo)
     """
+
+    def _parse_transformer_config(transformer_config: dict) -> dict:
+        config = {}
+        for k, v in transformer_config.items():
+            if isinstance(v, (bool, int, str)):
+                config[k] = v
+            else:
+                config[k] = str(v)
+        config = {k: v for k, v in config.items() if "fp4" not in k and "fp8" not in k}
+        config = {k: v for k, v in config.items() if "tp_" not in k and "parallel" not in k}
+        config = {k: v for k, v in config.items() if "cuda_graph" not in k}
+        config = {k: v for k, v in config.items() if "init_" not in k and "cpu" not in k}
+        config = {k: v for k, v in config.items() if "recompute" not in k and "inference" not in k}
+        config = {k: v for k, v in config.items() if "pipeline" not in k and "comm" not in k}
+        config = {k: v for k, v in config.items() if "batch" not in k}
+        return config
+
+    if dist.is_master():
+        run_config_name = f"{checkpoint_name}/modelopt_run_config.yaml"
+        config_dict = _parse_transformer_config(copy.deepcopy(model[0].config.__dict__))
+        config_dict["nvidia_modelopt_version"] = modelopt.__version__
+        with open(run_config_name, "w") as f:
+            yaml.dump(config_dict, f, default_flow_style=False)
+
     if not mto.ModeloptStateManager.is_converted(model[0]):
         return
     if len(model) > 1:
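The block added above writes a modelopt_run_config.yaml next to the checkpoint on the master rank: _parse_transformer_config() keeps only simple, architecture-level TransformerConfig entries (dropping fp4/fp8, parallelism, CUDA-graph, init, recompute, inference, pipeline, communication, and batch-related keys) and records the ModelOpt version alongside them. A minimal sketch of reading that file back, assuming a hypothetical checkpoint path and that num_layers survives the filtering (neither is part of this commit):

# Minimal sketch (not part of this commit): inspect the run config written by
# save_sharded_modelopt_state().
import yaml

checkpoint_dir = "/path/to/checkpoint"  # hypothetical checkpoint directory

with open(f"{checkpoint_dir}/modelopt_run_config.yaml") as f:
    run_config = yaml.safe_load(f)

# "nvidia_modelopt_version" is written by the new code; "num_layers" is an
# assumed TransformerConfig field that the key filtering above would keep.
print(run_config["nvidia_modelopt_version"])
print(run_config.get("num_layers"))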

modelopt/torch/utils/plugins/megatron_preprocess_data.py

Lines changed: 86 additions & 0 deletions
@@ -31,11 +31,14 @@
     ```
 """
 
+import argparse
 import json
 import multiprocessing
 import sys
 from pathlib import Path
 
+import requests
+from datasets import load_dataset
 from megatron.core.datasets import indexed_dataset
 from transformers import AutoTokenizer
 
@@ -198,3 +201,86 @@ def megatron_preprocess_data(
         final_enc_len += num_tokens
 
     print(f">>> Total number of tokens: {final_enc_len}")
+
+
+def main():
+    """Sample main function to process large data for pretraining.
+
+    Example usage:
+
+    >>> python megatron_preprocess_data.py \
+            --dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \
+            --tokenizer "nvidia/Nemotron-Pretraining-Tokenizer" \
+            --output_dir "./processed_data"
+    """
+    parser = argparse.ArgumentParser(prog="megatron_preprocess_data")
+    parser.add_argument("--input_path", type=str, default=None, help="Input path.")
+    parser.add_argument(
+        "--dataset", type=str, default=None, help="Hugging Face Hub dataset name or path"
+    )
+    parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset")
+    parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split")
+    parser.add_argument(
+        "--output_dir", type=str, default="./processed_data", help="Output directory"
+    )
+    parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path")
+    parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize")
+    parser.add_argument("--append_eod", type=bool, default=False, help="Append <eod> token")
+    parser.add_argument(
+        "--max_sequence_length", type=int, default=None, help="Maximum sequence length"
+    )
+    parser.add_argument("--workers", type=int, default=8, help="Number of worker processes")
+    parser.add_argument("--log_interval", type=int, default=1000, help="Log interval")
+    args = parser.parse_args()
+
+    if args.input_path is None:
+        args.input_path = []
+        if args.dataset is None:
+            args.dataset = "nvidia/Nemotron-Pretraining-Dataset-sample"
+
+        response = requests.get(
+            "https://datasets-server.huggingface.co/splits?dataset={}".format(args.dataset),
+            timeout=10,
+        )
+
+        for entry in response.json()["splits"]:
+            skip_processing = False
+            name = entry["dataset"]
+            subset = entry.get("config", None)
+            split = entry["split"]
+
+            if args.subset is not None and args.subset != subset:
+                continue
+            if args.split is not None and args.split != split:
+                continue
+
+            print(f"Loading dataset {name} with subset {subset} and split {split}")
+            dataset = load_dataset(name, subset, split=split)
+
+            for key in args.json_keys:
+                if key not in dataset.features:
+                    print(f"Key {key} not found in dataset features. Skipping...")
+                    skip_processing = True
+                    break
+
+            if skip_processing:
+                continue
+
+            json_file_path = args.output_dir + "/" + name + "_" + subset + "_" + split + ".jsonl"
+            dataset.to_json(json_file_path)
+            args.input_path += [json_file_path]
+
+    megatron_preprocess_data(
+        input_path=args.input_path,
+        output_dir=args.output_dir,
+        tokenizer_name_or_path=args.tokenizer,
+        json_keys=args.json_keys,
+        append_eod=args.append_eod,
+        max_sequence_length=args.max_sequence_length,
+        workers=args.workers,
+        log_interval=args.log_interval,
+    )
+
+
+if __name__ == "__main__":
+    main()
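If the JSONL shards are already on disk, the Hugging Face splits lookup in main() can be skipped and megatron_preprocess_data() called directly with the same parameters the argument parser exposes. A minimal sketch, assuming a hypothetical local JSONL file and that the function is importable from the module path shown in this diff:

# Minimal sketch (assumed paths and values, not part of this commit).
from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data

megatron_preprocess_data(
    input_path=["./processed_data/my_corpus.jsonl"],  # hypothetical local JSONL file
    output_dir="./processed_data",
    tokenizer_name_or_path="nvidia/Nemotron-Pretraining-Tokenizer",
    json_keys=["text"],
    append_eod=False,
    max_sequence_length=None,
    workers=8,
    log_interval=1000,
)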
