
Commit 0b202b9

Merge branch 'main' into jingyux/megatron-lora

2 parents: 0b310fb + adcb1a1

File tree: 2 files changed, +16 −8 lines

modelopt/torch/distill/plugins/megatron.py

Lines changed: 12 additions & 8 deletions
@@ -92,26 +92,30 @@ def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable
     return student_layer, teacher_layer, loss_fn


-def load_distillation_config(
-    config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
+def setup_distillation_config(
+    config_or_path: str | DistillationConfig | None,
+    student_cfg: "TransformerConfig",
+    teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
     """Read the distillation yaml config file specified by ``args.export_kd_cfg``.

     Args:
-        config_path: Path to user-defined distillation settings yaml file.
+        config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
             If `None`, uses default logits-only distillation mode for GPT models.
         student_cfg: Model config for student model.
         teacher_cfg: Model config for teacher model.

     WARNING: Assumes intermediate hidden sizes are always that found in the model config's ``hidden_size`` attribute.
     """
-    if config_path:
-        with open(config_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
-    else:
+    if config_or_path is None:
         logger.warning("Distillation config not provided. Using default.")
         cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        cfg = config_or_path
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)

     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
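The renamed function now dispatches on the input type rather than assuming a path. A minimal usage sketch, assuming `setup_distillation_config` and `DistillationConfig` are importable from `modelopt.torch.distill.plugins.megatron`, that `student_cfg`/`teacher_cfg` are Megatron-Core `TransformerConfig` objects built elsewhere, and that the YAML path is hypothetical:

# Usage sketch (hypothetical setup): the new signature accepts None,
# an already-built DistillationConfig, or a YAML path.
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    setup_distillation_config,
)

# student_cfg and teacher_cfg are assumed to be existing Megatron-Core
# TransformerConfig instances created elsewhere in the training script.

# 1) None -> default logits-only distillation mode, with a warning logged.
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)

# 2) An existing (possibly partial) config is accepted as-is, not re-parsed.
cfg = setup_distillation_config(DistillationConfig(), student_cfg, teacher_cfg)

# 3) A string is treated as a path to a user-defined YAML settings file.
cfg = setup_distillation_config("distill_cfg.yaml", student_cfg, teacher_cfg)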

modelopt/torch/export/unified_export_hf.py

Lines changed: 4 additions & 0 deletions
@@ -332,6 +332,10 @@ def _export_quantized_weight(

     setattr(sub_module, weight_name, nn.Parameter(quantized_weight, requires_grad=False))

+    # Register the corrected weight_scale as a buffer
+    if weight_scale is not None:
+        sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
+

 def _export_hf_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None
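For context, `register_buffer` stores a tensor in the module's `state_dict` without making it a learnable parameter, so the scale is exported alongside the quantized weight. A self-contained sketch of the pattern with hypothetical names (the per-tensor scale math is illustrative only, not ModelOpt's actual quantizer logic):

# Illustrative sketch: pair a non-trainable quantized weight with its
# scale registered as a buffer, so both land in state_dict().
import torch
import torch.nn as nn

lin = nn.Linear(16, 16)

# Fake int8 quantization for illustration (hypothetical scale math).
w = lin.weight.detach()
weight_scale = w.abs().amax() / 127.0
quantized_weight = torch.round(w / weight_scale).to(torch.int8)

# Replace the trainable weight with the quantized tensor...
setattr(lin, "weight", nn.Parameter(quantized_weight, requires_grad=False))
# ...and register the scale as a buffer: serialized with the module,
# but never returned by lin.parameters() or updated by an optimizer.
lin.register_buffer("weight_scale", weight_scale)

assert "weight_scale" in lin.state_dict()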
