diff --git a/docs/source/guides/4_distillation.rst b/docs/source/guides/4_distillation.rst index 4bf228ec..79f13673 100644 --- a/docs/source/guides/4_distillation.rst +++ b/docs/source/guides/4_distillation.rst @@ -16,9 +16,9 @@ a more powerful teacher model using :mod:`modelopt.torch.distill ` and - restore via :meth:`mto.restore `. See :ref:`saving and restoring ` - to learn more. +#. **Checkpoint and re-load**: Save the model via :meth:`mto.save `. + Note that restoring the model (via :meth:`mto.restore `) + will not reinstantiate the distillation meta-model, in order to avoid unpickling issues. *To find out more about Distillation and related concepts, please refer to the section below* :ref:`Distillation Concepts `. @@ -44,7 +44,7 @@ Example usage: # Configure and convert for distillation distillation_config = { - # `teacher_model` is a model class or callable, or a tuple. + # `teacher_model` is a model, model class, callable, or a tuple. # If a tuple, it must be of the form (model_cls_or_callable,) or # (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs). "teacher_model": teacher_model, @@ -53,15 +53,9 @@ Example usage: } distillation_model = mtd.convert(model, mode=[("kd_loss", distillation_config)]) - # Export model in original class form + # Export model in original class, with only previously-present attributes model_exported = mtd.export(distillation_model) -.. note:: - The config requires a (non-lambda) Callable to return a teacher model in place of the model - itself. This is to avoid re-saving the teacher state dict upon saving the Distillation - meta model. Thus, the same callable must be available in the namespace when restoring via - the :meth:`mto.restore ` utility. - .. tip:: When training the student on a small corpus of ground truth data, consider using :class:`MFTLoss ` to perform Minifinetuning in lieu of the standard :class:`LogitsDistillationLoss `. This will allow the student to learn from the teacher's distribution while adapting to the new data, improving specialization on the new data without overwriting the teacher's general knowledge. @@ -170,10 +164,12 @@ outputs in the same order as well: The intermediate outputs for the losses are captured by the :class:`DistillationModel ` and then the loss(es) are invoked using :meth:`DistillationModel.compute_kd_loss() `. -If present, the original student's non-distillation loss is passed in as an argument. +If present, the original student's non-distillation loss can be passed in as an argument. Writing a custom loss function is often necessary, especially to handle outputs that need to be processed -to obtain the logits and activations. +to obtain the logits and activations. Additional arguments to the loss function can be passed to +:meth:`DistillationModel.compute_kd_loss() ` +as ``kwargs``.
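For instance, a custom criterion might first unwrap Hugging-Face-style output objects to reach the logits, and accept an extra keyword that is forwarded through ``compute_kd_loss``. The following is a minimal sketch rather than part of the ModelOpt API: the ``loss_mask`` keyword, the ``student_loss`` keyword shown in the comment, and the assumption that the captured outputs expose a ``.logits`` attribute are all illustrative.

.. code-block:: python

    import torch.nn.functional as F
    from torch.nn.modules.loss import _Loss


    class CausalLMLogitsLoss(_Loss):
        """Toy KD criterion that operates on HF-style model outputs instead of raw tensors."""

        def forward(self, out_student, out_teacher, loss_mask=None):
            log_p_student = F.log_softmax(out_student.logits, dim=-1)
            p_teacher = F.softmax(out_teacher.logits, dim=-1)
            # Per-token cross entropy of the student against the teacher distribution
            kd = -(p_teacher * log_p_student).sum(dim=-1)
            if loss_mask is not None:  # e.g. mask out padding positions
                kd = kd * loss_mask
            return kd.mean()


    # During the training step, extra keyword arguments are forwarded to the criterion, e.g.:
    # loss = distillation_model.compute_kd_loss(student_loss=student_loss, loss_mask=loss_mask)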
Loss Balancer ^^^^^^^^^^^^^ diff --git a/examples/llm_distill/README.md b/examples/llm_distill/README.md index c8085265..6b97e8f7 100644 --- a/examples/llm_distill/README.md +++ b/examples/llm_distill/README.md @@ -39,13 +39,9 @@ First obtain both a pretrained model to act as the teacher and a (usually smalle ```python from transformers import AutoModelForCausalLM -# Define student +# Define student & teacher student_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct") - -# Define callable which returns teacher -def teacher_factory(): - teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct") - return teacher_model +teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B-Instruct") ``` ### Set up the meta model @@ -58,7 +54,7 @@ Please see an example Distillation setup below. This example assumes the outputs import modelopt.torch.distill as mtd distillation_config = { - "teacher_model": teacher_factory, # model initializer + "teacher_model": teacher_model, "criterion": mtd.LogitsDistillationLoss(), # callable receiving student and teacher outputs, in order "loss_balancer": mtd.StaticLossBalancer(), # combines multiple losses; omit if only one distillation loss used } @@ -66,7 +62,7 @@ distillation_config = { distillation_model = mtd.convert(student_model, mode=[("kd_loss", distillation_config)]) ``` -The `teacher_model` can be either a callable which returns an `nn.Module` or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed). +The `teacher_model` can be either a `nn.Module`, a callable which returns an `nn.Module`, or a tuple of `(model_cls, args, kwargs)`. The `criterion` is the distillation loss used between student and teacher tensors. The `loss_balancer` determines how the original and distillation losses are combined (if needed). See [Distillation](https://nvidia.github.io/TensorRT-Model-Optimizer/guides/4_distillation.html) for more info. 
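Once converted, the meta model is used like the original student: a forward pass also runs the teacher and captures the intermediate outputs, and `compute_kd_loss()` combines the distillation loss(es). A minimal training-step sketch follows; it assumes an existing `optimizer` over the student's trainable parameters and a `train_dataloader` yielding tokenized batches with labels, and the `student_loss` keyword name is shown for illustration only.

```python
distillation_model.train()

for batch in train_dataloader:
    optimizer.zero_grad()
    # The forward pass runs both student and teacher and captures their outputs
    outputs = distillation_model(**batch)
    # Combine the distillation loss(es), optionally folding in the original student loss
    loss = distillation_model.compute_kd_loss(student_loss=outputs.loss)
    loss.backward()
    optimizer.step()
```

For Hugging Face `Trainer`-based flows, the same logic is wrapped by the `KDSFTTrainer` defined in `main.py`.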
@@ -158,35 +154,33 @@ Keep in mind the training loss of the distillation run is not directly comparabl ### Train teacher ```bash -accelerate launch --multi_gpu --mixed_precision bf16 main.py \ +accelerate launch --config-file ./accelerate_config/fsdp2.yaml \ + main.py \ --single_model \ --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \ --output_dir ./llama2-7b-sft \ - --logging_steps 5 \ - --max_steps 400 \ - --max_seq_length 2048 \ + --max_length 2048 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 4 \ - --gradient_checkpointing True \ - --fsdp 'full_shard auto_wrap' \ - --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer + --max_steps 400 \ + --logging_steps 5 ``` ### Distill teacher into student ```bash -accelerate launch --multi_gpu --mixed_precision bf16 main.py \ +accelerate launch --config-file ./accelerate_config/fsdp2.yaml \ + --fsdp_cpu_ram_efficient_loading False \ + --fsdp_activation_checkpointing False \ + main.py \ --teacher_name_or_path ./llama2-7b-sft \ --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \ --output_dir ./llama2-distill \ - --logging_steps 5 \ - --max_steps 200 \ - --max_seq_length 2048 \ + --max_length 2048 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 4 \ - --gradient_checkpointing False \ - --fsdp 'full_shard auto_wrap' \ - --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer + --max_steps 200 \ + --logging_steps 5 ``` > [!NOTE] diff --git a/examples/llm_distill/accelerate_config/fsdp2.yaml b/examples/llm_distill/accelerate_config/fsdp2.yaml new file mode 100644 index 00000000..3c901d61 --- /dev/null +++ b/examples/llm_distill/accelerate_config/fsdp2.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: gpu +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/llm_distill/main.py b/examples/llm_distill/main.py index fa142c62..efbf36e8 100644 --- a/examples/llm_distill/main.py +++ b/examples/llm_distill/main.py @@ -21,7 +21,6 @@ import torch import torch.distributed import transformers -from accelerate import PartialState from accelerate.logging import get_logger from transformers import AutoTokenizer from trl import SFTTrainer @@ -48,38 +47,28 @@ class TrainingArguments(transformers.TrainingArguments): do_train: bool = True do_eval: bool = True save_strategy: str = "no" - max_seq_length: int = 1024 + max_length: int = 1024 optim: str = "adamw_torch" learning_rate: float = 1e-5 lr_scheduler_type: str = "cosine" dataloader_drop_last: bool = True dataset_num_proc: int = 8 - dataset_batch_size: int = 500 bf16: bool = True tf32: bool = True def llama_text_format_func(sample): - texts = [] - for p, q, r in zip(sample["system_prompt"], sample["question"], sample["response"]): - if not p: - texts.append(f"[INST] {q}[/INST]\n{r}") - else: - texts.append(f"[INST] <>{p}<>\n{q}[/INST]\n{r}") - return texts + p, q, r = sample["system_prompt"], sample["question"], sample["response"] + if not p: + return f"[INST] 
{q}[/INST]\n{r}" + else: + return f"[INST] <>{p}<>\n{q}[/INST]\n{r}" class KDSFTTrainer(SFTTrainer, KDTrainer): pass -def _teacher_factory(model_name_or_path): - return transformers.AutoModelForCausalLM.from_pretrained( - model_name_or_path, - device_map=PartialState().process_index, - ) - - def train(): parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) model_args, training_args = parser.parse_args_into_dataclasses() @@ -117,25 +106,24 @@ def train(): if model_args.single_model: logger.info("Loading single model only...") - model = _teacher_factory(model_path) + model = transformers.AutoModelForCausalLM.from_pretrained( + model_path, dtype=torch.bfloat16 if training_args.bf16 else None + ) logger.info("Model loaded.") else: logger.info("Loading student model...") model = transformers.AutoModelForCausalLM.from_pretrained( - model_args.student_name_or_path, - device_map=PartialState().process_index, + model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None ) logger.info("Student loaded.") # Load checkpoint logger.info("Loading teacher model and converting to Distillation model...") + teacher_model = transformers.AutoModelForCausalLM.from_pretrained( + model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None + ) kd_config = { - "teacher_model": ( - _teacher_factory, - (model_args.teacher_name_or_path,), - {}, - ), + "teacher_model": teacher_model, "criterion": LMLogitsLoss(), - "expose_minimal_state_dict": False, # FSDP forces us to disable this } model = mtd.convert(model, mode=[("kd_loss", kd_config)]) logger.info("Models converted.") @@ -143,8 +131,6 @@ def train(): # Fix problematic settings that logger.info excessive warnings model.generation_config.temperature = None model.generation_config.top_p = None - if training_args.gradient_checkpointing: - training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} # Trainer trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer diff --git a/examples/llm_distill/requirements.txt b/examples/llm_distill/requirements.txt index 11f4b709..b67db72d 100644 --- a/examples/llm_distill/requirements.txt +++ b/examples/llm_distill/requirements.txt @@ -1,2 +1,2 @@ pyarrow -trl==0.13.0 +trl>=0.23.0 diff --git a/modelopt/torch/distill/config.py b/modelopt/torch/distill/config.py index 5dcd00af..cfdb3ccb 100644 --- a/modelopt/torch/distill/config.py +++ b/modelopt/torch/distill/config.py @@ -16,20 +16,18 @@ """Configurations for distillation modes.""" import warnings -from collections.abc import Callable from typing import Any, Union import pydantic -import torch.nn as nn from torch.nn.modules.loss import _Loss as Loss from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField +from modelopt.torch.utils.network import ModelLike from .loss_balancers import DistillationLossBalancer __all__ = ["KDLossConfig"] -TeacherModel = type[nn.Module] | tuple | Callable Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007 @@ -42,14 +40,13 @@ class KDLossConfig(ModeloptBaseConfig): # TODO: we should really think about a better to configure KDLossConfig model_config = pydantic.ConfigDict(extra="forbid", arbitrary_types_allowed=True) - teacher_model: TeacherModel | None = ModeloptField( + teacher_model: ModelLike | None = ModeloptField( default=None, title="Teacher model", description=( - "The class or callable or tuple to initialize the teacher model using" + "The module, class, callable, or tuple to initialize the teacher model using" " 
:meth:`init_model_from_model_like" - " `. This cannot already be an" - " instance of nn.Module." + " `." ), ) criterion: Criterion | None = ModeloptField( diff --git a/modelopt/torch/distill/mode.py b/modelopt/torch/distill/mode.py index 52cc478f..b65a31cd 100644 --- a/modelopt/torch/distill/mode.py +++ b/modelopt/torch/distill/mode.py @@ -83,7 +83,12 @@ def restore(self) -> RestoreEntrypoint: @property def update_for_new_mode(self) -> UpdateEntrypoint: """The mode's entrypoint for updating the models state for adding new mode.""" - return _update_kd_state_before_new_mode + return _reset_kd_state_config + + @property + def update_for_save(self) -> UpdateEntrypoint: + """The mode's entrypoint for updating the models state before saving.""" + return _reset_kd_state_config @DistillModeRegistry.register_mode @@ -171,16 +176,12 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType def _restore_kd_model(model: nn.Module, config: KDLossConfig, metadata: MetadataDict) -> nn.Module: """Function for restoring a previously convert model to a distillation meta-model.""" - # the metadata should be empty - assert not metadata, "No metadata expected!" + # NOTE: DistillationModel will purposely remain unrestored + return model - return _convert_for_kd(model, config)[0] - -def _update_kd_state_before_new_mode( - model: nn.Module, config: KDLossConfig, metadata: MetadataDict -) -> None: - """Function for updating the model's state before new mode.""" +def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict): + """Function for resetting the state's config.""" config.teacher_model = nn.Module config.criterion = Loss() config.loss_balancer = None @@ -216,8 +217,5 @@ def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertRet def _restore_exported_student( model: nn.Module, config: ExportStudentConfig, metadata: MetadataDict ) -> nn.Module: - """Function for restoring a previously exported distillation meta-model.""" - # no metadata is used by the mode - assert not metadata, "No metadata expected!" 
- - return _export_student(model, config)[0] + # NOTE: DistillationModel was unrestored so this does nothing + return model diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py index 719b5b50..55b516a0 100644 --- a/modelopt/torch/distill/plugins/huggingface.py +++ b/modelopt/torch/distill/plugins/huggingface.py @@ -80,12 +80,6 @@ def save_model( state_dict=state_dict, ) self.processing_class.save_pretrained(output_dir) - if export_student: - modelopt_state["modelopt_state_dict"] = [ - state - for state in modelopt_state["modelopt_state_dict"] - if "kd_loss" not in state and "export_student" not in state - ] torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth") else: model = model.export() if export_student else model diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py index e6a0a2b7..9d429ffa 100644 --- a/modelopt/torch/quantization/plugins/transformers_trainer.py +++ b/modelopt/torch/quantization/plugins/transformers_trainer.py @@ -174,12 +174,6 @@ def _save_modelopt_state_with_weights(self): torch.distributed.barrier() modelopt_state = mto.modelopt_state(self.model) - # TODO: remove this from ModelOpt HF Trainer flows - modelopt_state["modelopt_state_dict"] = [ - state - for state in modelopt_state["modelopt_state_dict"] - if "kd_loss" not in state and "export_student" not in state - ] modelopt_state["modelopt_state_weights"] = get_quantizer_state_dict(self.model) if self.args.should_save: diff --git a/tests/examples/llm_distill/test_llm_distill.py b/tests/examples/llm_distill/test_llm_distill.py index dc885b45..8b1af7f7 100644 --- a/tests/examples/llm_distill/test_llm_distill.py +++ b/tests/examples/llm_distill/test_llm_distill.py @@ -21,18 +21,18 @@ def test_llama_distill(tiny_llama_path, tmp_path): run_example_command( [ - "accelerate", "launch", "--multi_gpu", "--mixed_precision", "bf16", "main.py", + "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml", + "--fsdp_cpu_ram_efficient_loading", "False", + "--fsdp_activation_checkpointing", "False", + "main.py", "--teacher_name_or_path", tiny_llama_path, "--student_name_or_path", tiny_llama_path, "--output_dir", tmp_path, - "--logging_steps", "5", - "--max_steps", "10", - "--max_seq_length", "1024", + "--max_length", "1024", "--per_device_train_batch_size", "2", "--per_device_eval_batch_size", "8", - "--gradient_checkpointing", "True", - "--fsdp", "full_shard auto_wrap", - "--fsdp_transformer_layer_cls_to_wrap", "LlamaDecoderLayer", + "--max_steps", "10", + "--logging_steps", "5", ], "llm_distill", ) diff --git a/tests/unit/torch/distill/test_distill.py b/tests/unit/torch/distill/test_distill.py index c989f144..c2b06d3a 100644 --- a/tests/unit/torch/distill/test_distill.py +++ b/tests/unit/torch/distill/test_distill.py @@ -147,19 +147,15 @@ def test_distillation_save_restore(distillation_model, tmp_path): new_student = tiny_mobilenet() distillation_model_new = mto.restore(new_student, tmp_path / "ckpt.pt") - assert isinstance(distillation_model_new, mtd.DistillationModel) - assert distillation_model_new.teacher_model is not None - - input = get_input_tensor() - - # disable dropout for deterministic results - distillation_model.eval() - distillation_model_new.eval() - - out = distillation_model(input) - out_new = distillation_model_new(input) + # Ensure state config was reset + manager = mto.ModeloptStateManager(distillation_model_new) + cfg = manager._state[-1][1]["config"] 
+ assert cfg["teacher_model"] == nn.Module + assert isinstance(next(iter(cfg["criterion"].values())), Loss) + assert cfg["loss_balancer"] is None - assert torch.allclose(out, out_new) + # Should not have restored anything + assert isinstance(distillation_model_new, type(new_student)) def test_distillation_export(distillation_model, tmp_path): diff --git a/tests/unit/torch/opt/plugins/test_hf_patching.py b/tests/unit/torch/opt/plugins/test_hf_patching.py index 4d795c8f..9122728e 100644 --- a/tests/unit/torch/opt/plugins/test_hf_patching.py +++ b/tests/unit/torch/opt/plugins/test_hf_patching.py @@ -25,15 +25,6 @@ import modelopt.torch.opt as mto -def _teacher_factory(model_name_or_path, teacher_model_type): - if teacher_model_type == "qwen3": - return get_tiny_qwen3() - else: - return AutoModelForCausalLM.from_pretrained( - model_name_or_path, - ) - - @pytest.mark.parametrize( ("model_cls", "teacher_model_type"), [ @@ -46,12 +37,13 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type): model_ref = model_cls.from_pretrained(tiny_llama_dir) + if teacher_model_type == "qwen3": + teacher_model = get_tiny_qwen3() + else: + teacher_model = AutoModelForCausalLM.from_pretrained(tiny_llama_dir) + kd_config = { - "teacher_model": ( - _teacher_factory, - (tiny_llama_dir, teacher_model_type), - {}, - ), + "teacher_model": teacher_model, "criterion": mtd.LogitsDistillationLoss(), "expose_minimal_state_dict": False, } @@ -61,6 +53,5 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type): model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model") tf_output_tester(model, model_test) - # since distill model contains loss function, we compare state of model and teacher model manually + # since distill model contains loss function, we compare state of model manually assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model) - assert mto.modelopt_state(model._teacher_model) == mto.modelopt_state(model_test._teacher_model)