@@ -97,7 +97,10 @@ def setup_distillation_config(
9797 student_cfg : "TransformerConfig" ,
9898 teacher_cfg : "TransformerConfig" ,
9999) -> DistillationConfig :
100- """Read the distillation yaml config file specified by ``args.export_kd_cfg``.
100+ """Setup and/or finalize the distillation config.
101+
102+ Either reads the distillation yaml config file from a path, fills in an
103+ existing DistillationConfig, or creates a default one.
101104
102105 Args:
103106 config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
@@ -117,34 +120,34 @@ def setup_distillation_config(
117120 cfg = yaml .safe_load (f )
118121 cfg = DistillationConfig (** cfg )
119122
120- criterion = {}
121- if student_cfg .pipeline_model_parallel_size == 1 or parallel_state .is_pipeline_last_stage ():
122- criterion [tuple (cfg .logit_layers )] = LogitsKLLoss (
123- student_cfg , temperature = cfg .logit_kl_temperature
124- )
125- # NOTE: Projection layer shared among intermediate layer pairs.
126- projection_layer = ProjectionLayer (student_cfg , teacher_cfg )
127-
128- for entry in cfg .intermediate_layer_pairs :
129- student_layer , teacher_layer , loss_fn = cfg .parse_intermediate_entry (entry )
130- if parallel_state .get_tensor_and_context_parallel_rank () == 0 :
131- logger .info (
132- "Distillation: Adding intermediate loss between"
133- f" `{ student_layer } ` of student (hidden size { student_cfg .hidden_size } ) and"
134- f" `{ teacher_layer } ` of teacher (hidden size { teacher_cfg .hidden_size } )."
135- )
136- student_layer = _adjust_layer_index_for_pp (student_layer , student_cfg )
137- teacher_layer = _adjust_layer_index_for_pp (teacher_layer , teacher_cfg )
138- criterion [(student_layer , teacher_layer )] = loss_fn (
139- student_cfg , projection_layer = projection_layer
123+ if cfg .criterion is None :
124+ criterion = {}
125+ if parallel_state .is_pipeline_last_stage ():
126+ criterion [tuple (cfg .logit_layers )] = LogitsKLLoss (
127+ student_cfg , temperature = cfg .logit_kl_temperature
140128 )
129+ # NOTE: Projection layer shared among intermediate layer pairs.
130+ projection_layer = ProjectionLayer (student_cfg , teacher_cfg )
131+
132+ for entry in cfg .intermediate_layer_pairs :
133+ student_layer , teacher_layer , loss_fn = cfg .parse_intermediate_entry (entry )
134+ if parallel_state .get_tensor_and_context_parallel_rank () == 0 :
135+ logger .info (
136+ "Distillation: Adding intermediate loss between"
137+ f" `{ student_layer } ` of student (hidden size { student_cfg .hidden_size } ) and"
138+ f" `{ teacher_layer } ` of teacher (hidden size { teacher_cfg .hidden_size } )."
139+ )
140+ student_layer = _adjust_layer_index_for_pp (student_layer , student_cfg )
141+ teacher_layer = _adjust_layer_index_for_pp (teacher_layer , teacher_cfg )
142+ criterion [(student_layer , teacher_layer )] = loss_fn (
143+ student_cfg , projection_layer = projection_layer
144+ )
145+ cfg .criterion = criterion
141146
142- loss_balancer = LogitsAndIntermediatesLossBalancer (
143- kd_loss_scale = cfg .kd_loss_scale , skip_original_loss = cfg .skip_lm_loss
144- )
145-
146- cfg .criterion = criterion
147- cfg .loss_balancer = loss_balancer
147+ if cfg .loss_balancer is None :
148+ cfg .loss_balancer = LogitsAndIntermediatesLossBalancer (
149+ kd_loss_scale = cfg .kd_loss_scale , skip_original_loss = cfg .skip_lm_loss
150+ )
148151
149152 return cfg
150153
0 commit comments