
Commit cf07926

Disable KD from saving any real state
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 682bf6d commit cf07926

5 files changed: +24 -45 lines changed


modelopt/torch/distill/config.py

Lines changed: 4 additions & 7 deletions
@@ -16,20 +16,18 @@
 """Configurations for distillation modes."""
 
 import warnings
-from collections.abc import Callable
 from typing import Any, Union
 
 import pydantic
-import torch.nn as nn
 from torch.nn.modules.loss import _Loss as Loss
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
+from modelopt.torch.utils.network import ModelLike
 
 from .loss_balancers import DistillationLossBalancer
 
 __all__ = ["KDLossConfig"]
 
-TeacherModel = type[nn.Module] | tuple | Callable
 Criterion = Union[Loss, dict[tuple[str, str], Loss]]  # noqa: UP007
 
 
@@ -42,14 +40,13 @@ class KDLossConfig(ModeloptBaseConfig):
     # TODO: we should really think about a better to configure KDLossConfig
     model_config = pydantic.ConfigDict(extra="forbid", arbitrary_types_allowed=True)
 
-    teacher_model: TeacherModel | None = ModeloptField(
+    teacher_model: ModelLike | None = ModeloptField(
         default=None,
         title="Teacher model",
         description=(
-            "The class or callable or tuple to initialize the teacher model using"
+            "The module, class, callable, or tuple to initialize the teacher model using"
             " :meth:`init_model_from_model_like"
-            " <modelopt.torch.utils.network.init_model_from_model_like>`. This cannot already be an"
-            " instance of nn.Module."
+            " <modelopt.torch.utils.network.init_model_from_model_like>`."
         ),
     )
     criterion: Criterion | None = ModeloptField(
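
The practical effect of the ModelLike switch is that teacher_model may now be an already-instantiated nn.Module, not only a class, callable, or (class, args, kwargs) tuple. Below is a minimal sketch of such a config; it assumes the usual mtd.convert entry point and uses toy nn.Sequential stand-ins rather than real models, with the criterion layer names chosen purely for illustration.

import torch.nn as nn

import modelopt.torch.distill as mtd

# Toy stand-ins for real student/teacher networks (illustration only).
student = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
teacher = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))

kd_config = {
    # With ModelLike, an instantiated nn.Module is accepted here, in addition
    # to a class, callable, or (class, args, kwargs) tuple as before.
    "teacher_model": teacher,
    # Criterion maps a pair of student/teacher submodule names to a loss module
    # (here the same name "2" on both sides, i.e. the final Linear layers).
    "criterion": {("2", "2"): nn.MSELoss()},
    "loss_balancer": None,
}

distillation_model = mtd.convert(student, mode=[("kd_loss", kd_config)])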

modelopt/torch/distill/mode.py

Lines changed: 12 additions & 14 deletions
@@ -83,7 +83,12 @@ def restore(self) -> RestoreEntrypoint:
     @property
     def update_for_new_mode(self) -> UpdateEntrypoint:
         """The mode's entrypoint for updating the models state for adding new mode."""
-        return _update_kd_state_before_new_mode
+        return _reset_kd_state_config
+
+    @property
+    def update_for_save(self) -> UpdateEntrypoint:
+        """The mode's entrypoint for updating the models state before saving."""
+        return _reset_kd_state_config
 
 
 @DistillModeRegistry.register_mode
@@ -171,16 +176,12 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType
 
 def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module:
     """Function for restoring a previously convert model to a distillation meta-model."""
-    # the metadata should be empty
-    assert not metadata, "No metadata expected!"
+    # NOTE: DistillationModel will purposely remain unrestored
+    return model
 
-    return _convert_for_kd(model, config)[0]
 
-
-def _update_kd_state_before_new_mode(
-    model: nn.Module, config: KDLossConfig, metadata: MetadataDict
-) -> None:
-    """Function for updating the model's state before new mode."""
+def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
+    """Function for resetting the state's config."""
     config.teacher_model = nn.Module
     config.criterion = Loss()
     config.loss_balancer = None
@@ -216,8 +217,5 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertRet
 def _restore_exported_student(
     model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict
 ) -> nn.Module:
-    """Function for restoring a previously exported distillation meta-model."""
-    # no metadata is used by the mode
-    assert not metadata, "No metadata expected!"
-
-    return _export_student(model, config)[0]
+    # NOTE: DistillationModel was unrestored so this does nothing
+    return model
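
Taken together, the new entrypoints mean a KD checkpoint no longer round-trips the meta-model: update_for_save strips the teacher, criterion, and balancer down to placeholders before serialization, and _restore_kd_model hands the model back untouched on load. A hedged sketch of the resulting save/restore flow, reusing the hypothetical distillation_model from the previous snippet and the standard mto.save/mto.restore API:

import torch.nn as nn

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto

# Saving: the stored KDLossConfig now only carries placeholder values
# (teacher_model=nn.Module, criterion=Loss(), loss_balancer=None).
mto.save(distillation_model, "ckpt.pt")

# Restoring: the kd_loss mode is a no-op on load, so the result is the plain
# student module rather than a rebuilt DistillationModel.
fresh_student = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
restored = mto.restore(fresh_student, "ckpt.pt")
assert not isinstance(restored, mtd.DistillationModel)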

modelopt/torch/distill/plugins/huggingface.py

Lines changed: 0 additions & 6 deletions
@@ -80,12 +80,6 @@ def save_model(
                 state_dict=state_dict,
             )
             self.processing_class.save_pretrained(output_dir)
-            if export_student:
-                modelopt_state["modelopt_state_dict"] = [
-                    state
-                    for state in modelopt_state["modelopt_state_dict"]
-                    if "kd_loss" not in state and "export_student" not in state
-                ]
             torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth")
         else:
             model = model.export() if export_student else model
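
Because the KD entries now hold only placeholder config, the trainer no longer has to filter kd_loss/export_student out of modelopt_state_dict before writing the file; the blob saved above can be re-applied to a fresh student as-is. A rough sketch of the consuming side, assuming the standard mto.restore_from_modelopt_state helper and a toy stand-in for the student:

import torch
import torch.nn as nn

import modelopt.torch.opt as mto

# Toy stand-in for a freshly constructed student model (illustration only).
fresh_student = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))

# Load the state blob written by save_model() above.
modelopt_state = torch.load("modelopt_state.pth", weights_only=False)

# kd_loss / export_student entries are still present in the state, but with
# this commit they restore as no-ops, so no teacher model is reconstructed.
model = mto.restore_from_modelopt_state(fresh_student, modelopt_state)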

modelopt/torch/quantization/plugins/transformers_trainer.py

Lines changed: 0 additions & 6 deletions
@@ -174,12 +174,6 @@ def _save_modelopt_state_with_weights(self):
         torch.distributed.barrier()
 
         modelopt_state = mto.modelopt_state(self.model)
-        # TODO: remove this from ModelOpt HF Trainer flows
-        modelopt_state["modelopt_state_dict"] = [
-            state
-            for state in modelopt_state["modelopt_state_dict"]
-            if "kd_loss" not in state and "export_student" not in state
-        ]
         modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model)
 
         if self.args.should_save:

tests/unit/torch/distill/test_distill.py

Lines changed: 8 additions & 12 deletions
@@ -147,19 +147,15 @@ def test_distillation_save_restore(distillation_model, tmp_path):
     new_student = tiny_mobilenet()
     distillation_model_new = mto.restore(new_student, tmp_path / "ckpt.pt")
 
-    assert isinstance(distillation_model_new, mtd.DistillationModel)
-    assert distillation_model_new.teacher_model is not None
-
-    input = get_input_tensor()
-
-    # disable dropout for deterministic results
-    distillation_model.eval()
-    distillation_model_new.eval()
-
-    out = distillation_model(input)
-    out_new = distillation_model_new(input)
+    # Ensure state config was reset
+    manager = mto.ModeloptStateManager(distillation_model_new)
+    cfg = manager._state[-1][1]["config"]
+    assert cfg["teacher_model"] == nn.Module
+    assert isinstance(next(iter(cfg["criterion"].values())), Loss)
+    assert cfg["loss_balancer"] is None
 
-    assert torch.allclose(out, out_new)
+    # Should not have restored anything
+    assert isinstance(distillation_model_new, type(new_student))
 
 
 def test_distillation_export(distillation_model, tmp_path):
