
Commit 75ba7cd

Minor change to make KD MCore config bit more flexible (#442)
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 718fd9e commit 75ba7cd

1 file changed

modelopt/torch/distill/plugins/megatron.py

Lines changed: 30 additions & 27 deletions
@@ -97,7 +97,10 @@ def setup_distillation_config(
     student_cfg: "TransformerConfig",
     teacher_cfg: "TransformerConfig",
 ) -> DistillationConfig:
-    """Read the distillation yaml config file specified by ``args.export_kd_cfg``.
+    """Setup and/or finalize the distillation config.
+
+    Either reads the distillation yaml config file from a path, fills in an
+    existing DistillationConfig, or creates a default one.
 
     Args:
         config_or_path: Path to user-defined distillation settings yaml file, or the incomplete config itself.
@@ -117,34 +120,34 @@ def setup_distillation_config(
         cfg = yaml.safe_load(f)
         cfg = DistillationConfig(**cfg)
 
-    criterion = {}
-    if student_cfg.pipeline_model_parallel_size == 1 or parallel_state.is_pipeline_last_stage():
-        criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
-            student_cfg, temperature=cfg.logit_kl_temperature
-        )
-    # NOTE: Projection layer shared among intermediate layer pairs.
-    projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
-
-    for entry in cfg.intermediate_layer_pairs:
-        student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry)
-        if parallel_state.get_tensor_and_context_parallel_rank() == 0:
-            logger.info(
-                "Distillation: Adding intermediate loss between"
-                f" `{student_layer}` of student (hidden size {student_cfg.hidden_size}) and"
-                f" `{teacher_layer}` of teacher (hidden size {teacher_cfg.hidden_size})."
-            )
-        student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
-        teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
-        criterion[(student_layer, teacher_layer)] = loss_fn(
-            student_cfg, projection_layer=projection_layer
+    if cfg.criterion is None:
+        criterion = {}
+        if parallel_state.is_pipeline_last_stage():
+            criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
+                student_cfg, temperature=cfg.logit_kl_temperature
         )
+        # NOTE: Projection layer shared among intermediate layer pairs.
+        projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
+
+        for entry in cfg.intermediate_layer_pairs:
+            student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry)
+            if parallel_state.get_tensor_and_context_parallel_rank() == 0:
+                logger.info(
+                    "Distillation: Adding intermediate loss between"
+                    f" `{student_layer}` of student (hidden size {student_cfg.hidden_size}) and"
+                    f" `{teacher_layer}` of teacher (hidden size {teacher_cfg.hidden_size})."
+                )
+            student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg)
+            teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg)
+            criterion[(student_layer, teacher_layer)] = loss_fn(
+                student_cfg, projection_layer=projection_layer
+            )
+        cfg.criterion = criterion
 
-    loss_balancer = LogitsAndIntermediatesLossBalancer(
-        kd_loss_scale=cfg.kd_loss_scale, skip_original_loss=cfg.skip_lm_loss
-    )
-
-    cfg.criterion = criterion
-    cfg.loss_balancer = loss_balancer
+    if cfg.loss_balancer is None:
+        cfg.loss_balancer = LogitsAndIntermediatesLossBalancer(
+            kd_loss_scale=cfg.kd_loss_scale, skip_original_loss=cfg.skip_lm_loss
+        )
 
     return cfg
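
In practice, the effect of this change is that setup_distillation_config() now only builds a default criterion when cfg.criterion is None and a default loss balancer when cfg.loss_balancer is None, so a partially filled DistillationConfig supplied by the caller is respected instead of being overwritten. A minimal usage sketch follows; the import path and constructor keywords are assumptions inferred from the diff above rather than verified API, and student_cfg / teacher_cfg are assumed to be Megatron-Core TransformerConfig objects built elsewhere in the training script.

# Sketch only: field names and import path are assumed from the diff above.
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    LogitsAndIntermediatesLossBalancer,
    setup_distillation_config,
)

# Pre-populate only the loss balancer; leave criterion as None.
partial_cfg = DistillationConfig(
    loss_balancer=LogitsAndIntermediatesLossBalancer(
        kd_loss_scale=2.0,          # custom scale, kept as-is after this commit
        skip_original_loss=False,
    ),
)

# criterion is still None, so the default LogitsKLLoss-based criterion (plus any
# intermediate-layer losses) is built on the last pipeline stage as before, while
# the caller-supplied loss_balancer is no longer overwritten.
# student_cfg / teacher_cfg: assumed to be existing TransformerConfig instances.
final_cfg = setup_distillation_config(partial_cfg, student_cfg, teacher_cfg)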
