diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py
index c1fa45f6b..7db0a6e14 100644
--- a/modelopt/torch/distill/plugins/megatron.py
+++ b/modelopt/torch/distill/plugins/megatron.py
@@ -92,26 +92,30 @@ def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable
     return student_layer, teacher_layer, loss_fn
 
 
-def load_distillation_config(
-    config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
+def setup_distillation_config(
+    config_or_path: str | DistillationConfig | None,
+    student_cfg: "TransformerConfig",
+    teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
     """Read the distillation yaml config file specified by ``args.export_kd_cfg``.
 
     Args:
-        config_path: Path to user-defined distillation settings yaml file.
+        config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
             If `None`, uses default logits-only distillation mode for GPT models.
         student_cfg: Model config for student model.
        teacher_cfg: Model config for teacher model.
     WARNING: Assumes intermediate hidden sizes are always that found in the model config's
         ``hidden_size`` attribute.
     """
-    if config_path:
-        with open(config_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
-    else:
+    if config_or_path is None:
         logger.warning("Distillation config not provided. Using default.")
         cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        cfg = config_or_path
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)
 
     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
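For reviewers, a minimal usage sketch of the renamed `setup_distillation_config` showing the three accepted input types. This is an illustration, not part of the change: `student_cfg` and `teacher_cfg` are assumed to be pre-built Megatron `TransformerConfig` objects from an already-initialized training setup, and `distill.yaml` is a hypothetical path.

```python
# Sketch only: assumes an environment with modelopt and megatron.core installed, and
# student_cfg / teacher_cfg already constructed as megatron TransformerConfig instances.
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    setup_distillation_config,
)

# None -> default logits-only distillation (emits a warning).
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)

# str -> treated as a YAML path, loaded via yaml.safe_load and wrapped in DistillationConfig.
cfg = setup_distillation_config("distill.yaml", student_cfg, teacher_cfg)  # hypothetical path

# DistillationConfig -> used as-is, skipping the YAML parsing step.
cfg = setup_distillation_config(DistillationConfig(), student_cfg, teacher_cfg)
```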