 import torch
 import torch.distributed
 import transformers
-from accelerate import PartialState
 from accelerate.logging import get_logger
 from transformers import AutoTokenizer
 from trl import SFTTrainer
@@ -108,21 +107,19 @@ def train():
     if model_args.single_model:
         logger.info("Loading single model only...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_path, device_map=PartialState().process_index
+            model_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         logger.info("Model loaded.")
     else:
         logger.info("Loading student model...")
         model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.student_name_or_path,
-            device_map=PartialState().process_index,
+            model_args.student_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         logger.info("Student loaded.")
         # Load checkpoint
         logger.info("Loading teacher model and converting to Distillation model...")
         teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
-            model_args.teacher_name_or_path,
-            device_map=PartialState().process_index,
+            model_args.teacher_name_or_path, dtype=torch.bfloat16 if training_args.bf16 else None
         )
         kd_config = {
             "teacher_model": teacher_model,
@@ -134,8 +131,6 @@ def train():
     # Fix problematic settings that logger.info excessive warnings
     model.generation_config.temperature = None
     model.generation_config.top_p = None
-    if training_args.gradient_checkpointing:
-        training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
 
     # Trainer
     trainer_cls = SFTTrainer if model_args.single_model else KDSFTTrainer
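
The net effect of the commit: models are loaded in bfloat16 when `training_args.bf16` is set, and the explicit per-process placement via `device_map=PartialState().process_index` is dropped so the Trainer (through Accelerate) handles device placement itself. Below is a minimal sketch of the resulting loading path, assuming `model_args` and `training_args` are the dataclasses parsed earlier in this script; the helper name `load_causal_lm` is made up for illustration and is not part of the committed code.

    import torch
    import transformers

    def load_causal_lm(name_or_path, bf16):
        # No device_map pinning anymore: when launched with `accelerate launch`
        # or `torchrun`, the HF Trainer moves the model to the local device.
        return transformers.AutoModelForCausalLM.from_pretrained(
            name_or_path,
            # `dtype` is the keyword used in this commit; older transformers
            # releases expect `torch_dtype` instead.
            dtype=torch.bfloat16 if bf16 else None,
        )

    # Illustrative usage mirroring the student/teacher loading after the change:
    # student = load_causal_lm(model_args.student_name_or_path, training_args.bf16)
    # teacher = load_causal_lm(model_args.teacher_name_or_path, training_args.bf16)

If non-reentrant gradient checkpointing is still wanted, `gradient_checkpointing_kwargs={"use_reentrant": False}` can be supplied when the `TrainingArguments` are constructed, rather than patched onto `training_args` after parsing as the removed lines did.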