@@ -92,26 +92,30 @@ def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable
     return student_layer, teacher_layer, loss_fn


-def load_distillation_config(
-    config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig"
+def setup_distillation_config(
+    config_or_path: str | DistillationConfig | None,
+    student_cfg: "TransformerConfig",
+    teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
     """Read the distillation yaml config file specified by ``args.export_kd_cfg``.

     Args:
-        config_path: Path to user-defined distillation settings yaml file.
+        config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
            If `None`, uses default logits-only distillation mode for GPT models.
        student_cfg: Model config for student model.
        teacher_cfg: Model config for teacher model.

    WARNING: Assumes intermediate hidden sizes are always that found in the model config's ``hidden_size`` attribute.
    """
-    if config_path:
-        with open(config_path) as f:
-            cfg = yaml.safe_load(f)
-        cfg = DistillationConfig(**cfg)
-    else:
+    if config_or_path is None:
        logger.warning("Distillation config not provided. Using default.")
        cfg = DistillationConfig()
+    elif isinstance(config_or_path, DistillationConfig):
+        cfg = config_or_path
+    else:
+        with open(config_or_path) as f:
+            cfg = yaml.safe_load(f)
+        cfg = DistillationConfig(**cfg)

    criterion = {}
    if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
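
For reviewers, a minimal usage sketch of the new entry point. The import path, the yaml file name, and the pre-built student/teacher TransformerConfig objects below are assumptions for illustration; only setup_distillation_config, DistillationConfig, and the three accepted argument kinds come from the diff itself.

# Hypothetical module path; adjust to wherever this function lives in the repo.
from distillation import DistillationConfig, setup_distillation_config

# student_cfg / teacher_cfg are assumed to be TransformerConfig instances built elsewhere.

# 1) Previous behavior, still supported: load settings from a user-provided yaml file
#    (file name is illustrative).
cfg = setup_distillation_config("kd_config.yaml", student_cfg, teacher_cfg)

# 2) New in this change: pass an existing (possibly partial) DistillationConfig directly,
#    bypassing the yaml read entirely.
cfg = setup_distillation_config(DistillationConfig(), student_cfg, teacher_cfg)

# 3) None still falls back to the default logits-only distillation mode.
cfg = setup_distillation_config(None, student_cfg, teacher_cfg)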