Skip to content

Commit 9e7c1a4

Browse files
committed
fix: feedback
Signed-off-by: Chenhan Yu <[email protected]>
1 parent e1f1bbb commit 9e7c1a4

File tree

2 files changed

+27
-14
lines changed

2 files changed

+27
-14
lines changed

modelopt/torch/opt/plugins/mcore_dist_checkpointing.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,21 @@
3636

3737
SUPPORTED_WRAPPERS[Float16Module] = "module"
3838

39+
DROP_SUBSTRINGS = [
40+
"fp4",
41+
"fp8",
42+
"tp_",
43+
"parallel",
44+
"cuda_graph",
45+
"init_",
46+
"cpu",
47+
"recompute",
48+
"inference",
49+
"pipeline",
50+
"comm",
51+
"batch",
52+
]
53+
3954

4055
def remove_per_module_state(
4156
modelopt_state: dict[str, Any],
@@ -126,18 +141,15 @@ def save_sharded_modelopt_state(
126141

127142
def _parse_transformer_config(transformer_config: dict) -> dict:
128143
config = {}
144+
129145
for k, v in transformer_config.items():
146+
if any(substring in k for substring in DROP_SUBSTRINGS):
147+
continue
130148
if isinstance(v, (bool, int, str)):
131149
config[k] = v
132150
else:
133151
config[k] = str(v)
134-
config = {k: v for k, v in config.items() if "fp4" not in k and "fp8" not in k}
135-
config = {k: v for k, v in config.items() if "tp_" not in k and "parallel" not in k}
136-
config = {k: v for k, v in config.items() if "cuda_graph" not in k}
137-
config = {k: v for k, v in config.items() if "init_" not in k and "cpu" not in k}
138-
config = {k: v for k, v in config.items() if "recompute" not in k and "inference" not in k}
139-
config = {k: v for k, v in config.items() if "pipeline" not in k and "comm" not in k}
140-
config = {k: v for k, v in config.items() if "batch" not in k}
152+
141153
return config
142154

143155
if dist.is_master():

modelopt/torch/utils/plugins/megatron_preprocess_data.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -210,13 +210,16 @@ def main():
210210
211211
>>> python megatron_preprocess_data.py \
212212
--dataset "nvidia/Nemotron-Pretraining-Dataset-sample" \
213-
--tokenizer "nvidia/Nemotron-Pretraining-Tokenizer" \
213+
--tokenizer "meta-llama/Llama-3.2-1B-Instruct" \
214214
--output_dir "./processed_data"
215215
"""
216216
parser = argparse.ArgumentParser(prog="megatron_preprocess_data")
217217
parser.add_argument("--input_path", type=str, default=None, help="Input path.")
218218
parser.add_argument(
219-
"--dataset", type=str, default=None, help="Hugging Face Hub dataset name or path"
219+
"--dataset",
220+
type=str,
221+
default="nvidia/Nemotron-Pretraining-Dataset-sample",
222+
help="Hugging Face Hub dataset name or path",
220223
)
221224
parser.add_argument("--subset", type=str, default=None, help="Hugging Face Hub dataset subset")
222225
parser.add_argument("--split", type=str, default="train", help="Hugging Face Hub dataset split")
@@ -225,7 +228,7 @@ def main():
225228
)
226229
parser.add_argument("--tokenizer", type=str, required=True, help="Tokenizer name or path")
227230
parser.add_argument("--json_keys", nargs="+", default=["text"], help="JSON keys to tokenize")
228-
parser.add_argument("--append_eod", type=bool, default=False, help="Append <eod> token")
231+
parser.add_argument("--append_eod", action="store_true", help="Append <eod> token")
229232
parser.add_argument(
230233
"--max_sequence_length", type=int, default=None, help="Maximum sequence length"
231234
)
@@ -235,8 +238,6 @@ def main():
235238

236239
if args.input_path is None:
237240
args.input_path = []
238-
if args.dataset is None:
239-
args.dataset = "nvidia/Nemotron-Pretraining-Dataset-sample"
240241

241242
response = requests.get(
242243
"https://datasets-server.huggingface.co/splits?dataset={}".format(args.dataset),
@@ -250,9 +251,9 @@ def main():
250251
split = entry["split"]
251252

252253
if args.subset is not None and args.subset != subset:
253-
continue
254+
skip_processing = True
254255
if args.split is not None and args.split != split:
255-
continue
256+
skip_processing = True
256257

257258
print(f"Loading dataset {name} with subset {subset} and split {split}")
258259
dataset = load_dataset(name, subset, split=split)

0 commit comments

Comments (0)