
Commit 549d20f

Review suggestions and fix test
Signed-off-by: Asha Anoosheh <[email protected]>
1 parent 80c10a2 commit 549d20f

6 files changed: +51 −54 lines changed

examples/llm_distill/README.md (8 additions, 22 deletions)

````diff
@@ -154,47 +154,33 @@ Keep in mind the training loss of the distillation run is not directly comparabl
 ### Train teacher
 
 ```bash
-accelerate launch \
-    --multi_gpu \
-    --mixed_precision bf16 \
-    --fsdp_version 2 \
-    --fsdp_reshard_after_forward True \
-    --fsdp_auto_wrap_policy 'TRANSFORMER_BASED_WRAP' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
-    \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
     main.py \
     --single_model \
     --teacher_name_or_path 'meta-llama/Llama-2-7b-hf' \
     --output_dir ./llama2-7b-sft \
-    --logging_steps 5 \
-    --max_steps 400 \
     --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing True
+    --max_steps 400 \
+    --logging_steps 5
 ```
 
 ### Distill teacher into student
 
 ```bash
-accelerate launch \
-    --multi_gpu \
-    --mixed_precision bf16 \
-    --fsdp_version 2 \
-    --fsdp_reshard_after_forward True \
-    --fsdp_auto_wrap_policy 'TRANSFORMER_BASED_WRAP' \
-    --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer \
-    \
+accelerate launch --config-file ./accelerate_config/fsdp2.yaml \
+    --fsdp_cpu_ram_efficient_loading False \
+    --fsdp_activation_checkpointing False \
     main.py \
     --teacher_name_or_path ./llama2-7b-sft \
     --student_name_or_path 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T' \
     --output_dir ./llama2-distill \
-    --logging_steps 5 \
-    --max_steps 200 \
     --max_length 2048 \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 4 \
-    --gradient_checkpointing False
+    --max_steps 200 \
+    --logging_steps 5
 ```
 
 > [!NOTE]
````
examples/llm_distill/accelerate_config/fsdp2.yaml (new file, 25 additions)

```diff
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+fsdp_config:
+  fsdp_activation_checkpointing: true
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: gpu
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
```
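Only two of these `fsdp_config` fields are changed at launch time by the distillation command and the updated test: `fsdp_cpu_ram_efficient_loading` and `fsdp_activation_checkpointing` are flipped to `False` on the CLI, while everything else comes straight from this file. Below is a minimal inspection sketch, not part of the commit, assuming PyYAML is installed and that the config lives at `examples/llm_distill/accelerate_config/fsdp2.yaml` (path inferred from the README changes above):

```python
# Hypothetical helper (not in this commit): print the FSDP settings from the new
# accelerate config and flag the two keys the distill command overrides on the CLI.
import yaml  # assumes PyYAML is available

CONFIG_PATH = "examples/llm_distill/accelerate_config/fsdp2.yaml"  # inferred location
CLI_OVERRIDES = {"fsdp_cpu_ram_efficient_loading", "fsdp_activation_checkpointing"}

with open(CONFIG_PATH) as f:
    cfg = yaml.safe_load(f)

for key, value in cfg["fsdp_config"].items():
    note = "  <- overridden to False on the distill/test command line" if key in CLI_OVERRIDES else ""
    print(f"{key}: {value}{note}")
```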

examples/llm_distill/main.py (3 additions, 8 deletions)

```diff
@@ -21,7 +21,6 @@
 import torch
 import torch.distributed
 import transformers
-from accelerate import PartialState
 from accelerate.logging import get_logger
 from transformers import AutoTokenizer
 from trl import SFTTrainer
@@ -108,21 +107,19 @@ def train():
     if model_args.single_model:
         logger.info("Loading single model only...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_path, device_map=PartialState().process_index
+            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         logger.info("Model loaded.")
     else:
         logger.info("Loading student model...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path,
-            device_map=PartialState().process_index,
+            model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         logger.info("Student loaded.")
         # Load checkpoint
         logger.info("Loading teacher model and converting to Distillation model...")
         teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path,
-            device_map=PartialState().process_index,
+            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         kd_config = {
             "teacher_model": teacher_model,
@@ -134,8 +131,6 @@ def train():
     # Fix problematic settings that logger.info excessive warnings
     model.generation_config.temperature = None
     model.generation_config.top_p = None
-    if training_args.gradient_checkpointing:
-        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
 
     # Trainer
     trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer
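The `main.py` change drops the per-process `device_map` placement (and the `PartialState` import it needed) in favor of a dtype hint, leaving device placement to the accelerate/FSDP2 launch config. A minimal sketch of the resulting load pattern, assuming a transformers release whose `from_pretrained` accepts the `dtype` keyword (as used in the diff above); the helper name and standalone `bf16` flag are illustrative only:

```python
# Illustrative sketch only (not part of the commit): pick the checkpoint dtype from a
# bf16 flag and let the launcher/FSDP handle device placement.
import torch
import transformers


def load_causal_lm(name_or_path: str, bf16: bool):
    # None keeps the library default; bfloat16 matches `mixed_precision: bf16` in the config.
    dtype = torch.bfloat16 if bf16 else None
    return transformers.AutoModelForCausalLM.from_pretrained(name_or_path, dtype=dtype)


# Example usage:
# student = load_causal_lm("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", bf16=True)
```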
```

examples/llm_distill/requirements.txt (1 addition, 1 deletion)

```diff
@@ -1,2 +1,2 @@
 pyarrow
-trl==0.23.0
+trl>=0.23.0
```

tests/examples/llm_distill/test_llm_distill.py (7 additions, 7 deletions)

```diff
@@ -21,18 +21,18 @@
 def test_llama_distill(tiny_llama_path, tmp_path):
     run_example_command(
         [
-            "accelerate", "launch", "--multi_gpu", "--mixed_precision", "bf16", "main.py",
+            "accelerate", "launch", "--config-file", "./accelerate_config/fsdp2.yaml",
+            "--fsdp_cpu_ram_efficient_loading", "False",
+            "--fsdp_activation_checkpointing", "False",
+            "main.py",
             "--teacher_name_or_path", tiny_llama_path,
             "--student_name_or_path", tiny_llama_path,
             "--output_dir", tmp_path,
-            "--logging_steps", "5",
-            "--max_steps", "10",
-            "--max_seq_length", "1024",
+            "--max_length", "1024",
             "--per_device_train_batch_size", "2",
             "--per_device_eval_batch_size", "8",
-            "--gradient_checkpointing", "True",
-            "--fsdp", "full_shard auto_wrap",
-            "--fsdp_transformer_layer_cls_to_wrap", "LlamaDecoderLayer",
+            "--max_steps", "10",
+            "--logging_steps", "5",
         ],
         "llm_distill",
     )
```

tests/unit/torch/opt/plugins/test_hf_patching.py (7 additions, 16 deletions)

```diff
@@ -25,15 +25,6 @@
 import modelopt.torch.opt as mto
 
 
-def _teacher_factory(model_name_or_path, teacher_model_type):
-    if teacher_model_type == "qwen3":
-        return get_tiny_qwen3()
-    else:
-        return AutoModelForCausalLM.from_pretrained(
-            model_name_or_path,
-        )
-
-
 @pytest.mark.parametrize(
     ("model_cls", "teacher_model_type"),
     [
@@ -46,12 +37,13 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type):
 
     model_ref = model_cls.from_pretrained(tiny_llama_dir)
 
+    if teacher_model_type == "qwen3":
+        teacher_model = get_tiny_qwen3()
+    else:
+        teacher_model = AutoModelForCausalLM.from_pretrained(tiny_llama_dir)
+
     kd_config = {
-        "teacher_model": (
-            _teacher_factory,
-            (tiny_llama_dir, teacher_model_type),
-            {},
-        ),
+        "teacher_model": teacher_model,
         "criterion": mtd.LogitsDistillationLoss(),
         "expose_minimal_state_dict": False,
     }
@@ -61,6 +53,5 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type):
     model_test = model_cls.from_pretrained(tiny_llama_dir / "modelopt_model")
 
     tf_output_tester(model, model_test)
-    # since distill model contains loss function, we compare state of model and teacher model manually
+    # since distill model contains loss function, we compare state of model manually
     assert mto.modelopt_state(model.model) == mto.modelopt_state(model_test.model)
-    assert mto.modelopt_state(model._teacher_model) == mto.modelopt_state(model_test._teacher_model)
```
