Skip to content

Commit 3ee4832

Browse files
Fix params_dtype for distillation and GPT HF Exporter head_dim for pruning (#12792) (#13002)
* Fix GPT HF Exporter dtype and head_dim
* Fix params_dtype
* Apply isort and black reformatting
---------
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: kevalmorabia97 <kevalmorabia97@users.noreply.github.com>
Co-authored-by: kevalmorabia97 <kevalmorabia97@users.noreply.github.com>
1 parent 2ab0aeb commit 3ee4832

File tree

7 files changed

+36
-7
lines changed

7 files changed

+36
-7
lines changed

nemo/collections/llm/gpt/model/gemma.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -284,6 +284,11 @@ def config(self) -> "GemmaConfig":
284284
hidden_size=source.hidden_size,
285285
intermediate_size=source.ffn_hidden_size,
286286
num_attention_heads=source.num_attention_heads,
287+
head_dim=(
288+
source.kv_channels
289+
if source.kv_channels is not None
290+
else source.hidden_size // source.num_attention_heads
291+
),
287292
max_position_embeddings=source.seq_length,
288293
initializer_range=source.init_method_std,
289294
rms_norm_eps=source.layernorm_epsilon,

nemo/collections/llm/gpt/model/gemma2.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -366,6 +366,11 @@ def config(self) -> "Gemma2Config":
366366
hidden_size=source.hidden_size,
367367
intermediate_size=source.ffn_hidden_size,
368368
num_attention_heads=source.num_attention_heads,
369+
head_dim=(
370+
source.kv_channels
371+
if source.kv_channels is not None
372+
else source.hidden_size // source.num_attention_heads
373+
),
369374
max_position_embeddings=source.seq_length,
370375
initializer_range=source.init_method_std,
371376
rms_norm_eps=source.layernorm_epsilon,

nemo/collections/llm/gpt/model/llama.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -704,6 +704,11 @@ def config(self) -> "HFLlamaConfig":
704704
hidden_size=source.hidden_size,
705705
intermediate_size=source.ffn_hidden_size,
706706
num_attention_heads=source.num_attention_heads,
707+
head_dim=(
708+
source.kv_channels
709+
if source.kv_channels is not None
710+
else source.hidden_size // source.num_attention_heads
711+
),
707712
max_position_embeddings=source.seq_length,
708713
initializer_range=source.init_method_std,
709714
rms_norm_eps=source.layernorm_epsilon,

nemo/collections/llm/gpt/model/qwen2.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -378,6 +378,11 @@ def config(self) -> "HFQwen2Config":
378378
hidden_size=source.hidden_size,
379379
intermediate_size=source.ffn_hidden_size,
380380
num_attention_heads=source.num_attention_heads,
381+
head_dim=(
382+
source.kv_channels
383+
if source.kv_channels is not None
384+
else source.hidden_size // source.num_attention_heads
385+
),
381386
max_position_embeddings=source.seq_length,
382387
initializer_range=source.init_method_std,
383388
rms_norm_eps=source.layernorm_epsilon,

nemo/collections/llm/modelopt/model_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -134,7 +134,9 @@ def setup_trainer_and_restore_model_with_modelopt_spec(
134134
num_nodes=num_nodes,
135135
accelerator="gpu",
136136
strategy=strategy,
137-
plugins=nl.MegatronMixedPrecision(precision="bf16", params_dtype=torch.bfloat16, autocast_enabled=True),
137+
plugins=nl.MegatronMixedPrecision(
138+
precision="bf16-mixed", params_dtype=torch.bfloat16, autocast_enabled=False, grad_reduce_in_fp32=True
139+
),
138140
**trainer_kwargs,
139141
)
140142

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -902,11 +902,12 @@ def save_checkpoint(
902902
ckpt_io = self.checkpoint_io
903903
if isinstance(ckpt_io, _WrappingCheckpointIO):
904904
ckpt_io = ckpt_io.checkpoint_io
905-
mto.plugins.save_sharded_modelopt_state(
906-
[core_model],
907-
ckpt_to_weights_subdir(filepath, is_saving=True),
908-
sharded_strategy=ckpt_io.save_sharded_strategy,
909-
)
905+
with core_model.hide_teacher_model() if hasattr(core_model, "hide_teacher_model") else nullcontext():
906+
mto.plugins.save_sharded_modelopt_state(
907+
[core_model],
908+
ckpt_to_weights_subdir(filepath, is_saving=True),
909+
sharded_strategy=ckpt_io.save_sharded_strategy,
910+
)
910911
logging.info("Saved Model-Optimizer state into checkpoint.")
911912

912913
def should_restore_optimizer_states(self, selective_restore: bool = False) -> bool:

scripts/llm/gpt_distillation.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
import os
1616
from argparse import ArgumentParser
1717

18+
import torch
1819
from lightning.pytorch.loggers import TensorBoardLogger
1920
from megatron.core.dist_checkpointing.validation import StrictHandling
2021
from megatron.core.optimizer import OptimizerConfig
@@ -82,7 +83,12 @@ def get_args():
8283
limit_val_batches=args.limit_val_batches,
8384
strategy=strategy,
8485
accelerator="gpu",
85-
plugins=nl.MegatronMixedPrecision(precision=args.precision),
86+
plugins=nl.MegatronMixedPrecision(
87+
precision=args.precision,
88+
params_dtype=torch.bfloat16 if "bf16" in args.precision else torch.float32,
89+
autocast_enabled=False,
90+
grad_reduce_in_fp32=True,
91+
),
8692
)
8793

8894
# Set up dataset

0 commit comments

Comments (0)