@@ -92,26 +92,30 @@ def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable
     return student_layer, teacher_layer, loss_fn


-def load_distillation_config(
-    config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
+def setup_distillation_config(
+    config_or_path: str | DistillationConfig | None,
+    student_cfg: "TransformerConfig",
+    teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
     """Read the distillation yaml config file specified by ``args.export_kd_cfg``.

     Args:
-        config_path: Path to user-defined distillation settings yaml file.
+        config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
             If `None`, uses default logits-only distillation mode for GPT models.
         student_cfg: Model config for student model.
         teacher_cfg: Model config for teacher model.

     WARNING: Assumes intermediate hidden sizes are always that found in the model config's ``hidden_size`` attribute.
     """
-    if config_path:
-        with open(config_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
-    else:
+    if config_or_path is None:
         logger.warning("Distillation config not provided. Using default.")
         cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        cfg = config_or_path
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)

     criterion = {}
     if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
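For context, a minimal usage sketch of the three input modes the renamed function accepts after this change. The module path for the import, the `TransformerConfig` construction, and the YAML file name below are assumptions for illustration, not taken from this diff:

```python
# Sketch of the three dispatch branches of setup_distillation_config().
# NOTE: import paths and config values here are assumptions, not from this PR.
from megatron.core.transformer.transformer_config import TransformerConfig

from distill_plugin import DistillationConfig, setup_distillation_config  # hypothetical module path

student_cfg = TransformerConfig(num_layers=12, hidden_size=768, num_attention_heads=12)
teacher_cfg = TransformerConfig(num_layers=24, hidden_size=1024, num_attention_heads=16)

# 1) None -> logs a warning and falls back to the default (logits-only) config.
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)

# 2) str -> treated as a path to a YAML file, parsed with yaml.safe_load()
#    and expanded into DistillationConfig(**cfg).
cfg = setup_distillation_config("distill_config.yaml", student_cfg, teacher_cfg)

# 3) DistillationConfig -> the (possibly incomplete) object is used as-is,
#    then filled out against the student/teacher model configs.
cfg = setup_distillation_config(DistillationConfig(), student_cfg, teacher_cfg)
```

Accepting an already-constructed `DistillationConfig` lets callers build the config programmatically instead of requiring a YAML file on disk, while `None` and path inputs keep their previous behavior.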