This repository contains an implementation of Iterative Layer-wise Distillation, a structured approach for distilling LLMs by ranking and removing transformer layers based on their contribution to downstream performance. The approach is inspired by ShortGPT (2024).
The method iteratively prunes layers and fine-tunes the resulting student model using a diverse set of benchmarks covering reasoning, summarization, translation, and generation tasks.
Layer importance is calculated by evaluating the model with the target layer removed on seven datasets from the LLMTF benchmark.
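As a rough illustration of this ranking step, the sketch below ablates one decoder layer at a time and measures the resulting score drop. The `evaluate_on_benchmarks` helper is hypothetical (the real pipeline evaluates with the LLMTF harness), and the layer-deletion details are simplified relative to the repository's code:

```python
# Minimal sketch of layer-importance ranking (not the repository's exact code).
# `evaluate_on_benchmarks(model) -> float` is a hypothetical helper that returns
# an average score over the benchmark tasks.
import copy
import torch
from transformers import AutoModelForCausalLM

def rank_layers_by_importance(model_name: str, evaluate_on_benchmarks):
    base = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    baseline = evaluate_on_benchmarks(base)

    importance = {}
    for idx in range(base.config.num_hidden_layers):
        # Deep-copying a full model per layer is memory-heavy; shown here only for clarity.
        pruned = copy.deepcopy(base)
        # Drop the target decoder layer; a full implementation would also
        # re-index `layer_idx` on the remaining attention modules.
        del pruned.model.layers[idx]
        pruned.config.num_hidden_layers = len(pruned.model.layers)
        # Importance = how much the benchmark score drops without this layer.
        importance[idx] = baseline - evaluate_on_benchmarks(pruned)

    # Layers with the smallest drop are the best candidates for removal.
    return sorted(importance, key=importance.get)
```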
Clone this repository and set up the environment:

```bash
git clone https://github.com/kaengreg/layer-wise_distillation.git
cd layer-wise_distillation
conda env create -f environment.yml
conda activate layerwise-distillation
pip install -r requirements.txt
```
⚠️ Make sure to install `torch` and CUDA-specific dependencies manually as needed for your setup.
Run the distillation script:

```bash
python3 Qwen2Distillation.py \
--student_model_path $STUDENT \
--distil_layers $TRAIN_LAYERS \
--removed_layers_iterations $PRUNE_LAYERS_1 \
--removed_layers_iterations $PRUNE_LAYERS_2 \
--removed_layers_iterations $PRUNE_LAYERS_3 \
--removed_layers_iterations $PRUNE_LAYERS_4 \
--learning_rate $LR \
--num_train_epochs $EPOCHS \
--per_device_train_batch_size $BS \
--per_device_eval_batch_size $BS \
--gradient_accumulation_steps $GRADACM \
--maxlen $MAXLEN \
--ds_frac $DSFRAC \
--use_local_data true \
--norm_factor $NORMFACT \
--output_dir $OUTPUT_DIR \
--logging_dir $LOGGING_DIR
```

For multi-node, multi-GPU training on SLURM, use the provided scripts and replace `python3` with `torchrun` for distributed training.
| Argument | Description |
|---|---|
| `--teacher_model_path` | Path to the full teacher model (default: `"Qwen/Qwen2.5-3B"`). |
| `--student_model_path` | Path to the student model (default: `"kngrg/Qwen2.5-3B-trimmed2"`). |
| `--distil_layers` | List of layer indices to retain in the student model (e.g., `--distil_layers 0 1 2 3`). |
| `--removed_layers_iterations` | List(s) of layer indices to be removed at each distillation iteration (e.g., `--removed_layers_iterations 4 5 --removed_layers_iterations 6 7`). |
| `--train_dataset` | Name or path of the Hugging Face dataset to use (default: `"kngrg/ru-miracl-cleaned"`). |
| `--learning_rate` | Learning rate used for optimization (default: `1e-4`). |
| `--num_train_epochs` | Number of training epochs per iteration (default: `1`). |
| `--per_device_train_batch_size` | Batch size per GPU for training (default: `16`). |
| `--per_device_eval_batch_size` | Batch size per GPU for evaluation (default: `16`). |
| `--gradient_accumulation_steps` | Number of forward-backward passes before one optimizer step (default: `16`). |
| `--eval_steps` | Evaluate the model every N steps (default: `50`). |
| `--save_steps` | Save a checkpoint every N steps (default: `512`). |
| `--logging_steps` | Log training metrics every N steps (default: `1`). |
| `--warmup_steps` | Linear warm-up over this many steps (default: `8`). |
| `--max_grad_norm` | Gradient clipping threshold (default: `0.3`). |
| `--weight_decay` | Weight decay coefficient for regularization (default: `0.05`). |
| `--bf16` | Enable bfloat16 mixed-precision training (default: `True`). |
| `--fp16` | Enable float16 mixed-precision training (default: `False`). |
| `--maxlen` | Maximum token sequence length (default: `512`). |
| `--ds_frac` | Number of training samples to use per epoch (default: `3`). |
| `--use_local_data` | Whether to load the dataset from local disk (`true` or `false`, default: `false`). |
| `--norm_factor` | Normalization factor applied to the components of the loss function (default: `0.1`). |
| `--output_dir` | Directory to save the distilled model checkpoints (default: `./qwen2.5-3b-trimmed2-logits+hs`). |
| `--logging_dir` | Directory to store training logs (default: `./logs-logits+hs`). |
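For intuition about how `--distil_layers` and `--norm_factor` interact, below is a hedged sketch of a combined logits-plus-hidden-state distillation loss, in line with the `logits+hs` naming of the default output directory. The function name, the temperature, and the student-to-teacher layer mapping are illustrative assumptions, not the exact formulation used in `Qwen2Distillation.py`:

```python
# Illustrative sketch of a combined logits + hidden-state distillation loss,
# scaled by `norm_factor`; the exact loss in Qwen2Distillation.py may differ.
import torch
import torch.nn.functional as F

def distillation_loss(student_out, teacher_out, distil_layers, norm_factor=0.1, temperature=1.0):
    # Both forward passes are assumed to run with output_hidden_states=True.
    # KL divergence between softened teacher and student token distributions.
    s_logp = F.log_softmax(student_out.logits / temperature, dim=-1)
    t_prob = F.softmax(teacher_out.logits / temperature, dim=-1)
    logits_loss = F.kl_div(s_logp, t_prob, reduction="batchmean") * temperature ** 2

    # MSE between student hidden states and the retained teacher layers
    # (assumes student layer i mirrors teacher layer distil_layers[i]).
    hs_loss = 0.0
    for s_idx, t_idx in enumerate(distil_layers):
        hs_loss = hs_loss + F.mse_loss(
            student_out.hidden_states[s_idx + 1],   # +1 skips the embedding output
            teacher_out.hidden_states[t_idx + 1],
        )
    hs_loss = hs_loss / len(distil_layers)

    # `norm_factor` balances the hidden-state term against the logits term.
    return logits_loss + norm_factor * hs_loss
```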
- Qwen2.5-3B → Qwen2.5-2B-layerwise-distilled: a reduced-size version of the Qwen2.5-3B model that preserves most of its performance with fewer parameters.
[TODO] Add a detailed report covering the methodology, pruning strategy, evaluation metrics, and final benchmark results.
