Commit adcb1a1

Allow passing in DistillationConfig directly to setup fn (#399)
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent baf55f2


modelopt/torch/distill/plugins/megatron.py

Lines changed: 12 additions & 8 deletions
@@ -92,26 +92,30 @@ def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable
     return student_layer, teacher_layer, loss_fn
 
 
-def load_distillation_config(
-    config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
+def setup_distillation_config(
+    config_or_path: str | DistillationConfig | None,
+    student_cfg: "TransformerConfig",
+    teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
     """Read the distillation yaml config file specified by ``args.export_kd_cfg``.
 
     Args:
-        config_path: Path to user-defined distillation settings yaml file.
+        config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
             If `None`, uses default logits-only distillation mode for GPT models.
         student_cfg: Model config for student model.
         teacher_cfg: Model config for teacher model.
 
     WARNING: Assumes intermediate hidden sizes are always that found in the model config's ``hidden_size`` attribute.
     """
-    if config_path:
-        with open(config_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
-    else:
+    if config_or_path is None:
         logger.warning("Distillation config not provided. Using default.")
         cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        cfg = config_or_path
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)
 
     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
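For reference, a minimal usage sketch of the updated function is shown below. It is not part of the commit: the import path, the `distill_kd.yaml` filename, and the pre-existing `student_cfg`/`teacher_cfg` Megatron-Core `TransformerConfig` objects are assumptions made for illustration only.

```python
# Sketch only. Assumptions: the import path below, and that student_cfg /
# teacher_cfg are Megatron-Core TransformerConfig instances built elsewhere.
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    setup_distillation_config,
)

# 1) YAML path, as before: the file is parsed into a DistillationConfig.
#    ("distill_kd.yaml" is a hypothetical filename.)
cfg = setup_distillation_config("distill_kd.yaml", student_cfg, teacher_cfg)

# 2) New with this commit: pass an already-constructed (possibly partial)
#    DistillationConfig directly, skipping the YAML round-trip.
kd_cfg = DistillationConfig()
cfg = setup_distillation_config(kd_cfg, student_cfg, teacher_cfg)

# 3) None still falls back to the default logits-only distillation mode.
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)
```

Callers that already build their distillation settings programmatically can now hand the config object straight to the setup function instead of serializing it to YAML first.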
