26
26
mto .enable_huggingface_checkpointing ()
27
27
28
28
# Training hyperparameters: number of epochs, and logging/checkpoint intervals counted in global steps
29
- EPOCHS = 1
30
- LOG_INTERVAL = 1
29
+ EPOCHS = 10
30
+ LOG_INTERVAL = 100
31
31
SAVE_INTERVAL = 20000
32
32
# VALIDATE_INTERVAL = 20
33
33
@@ -125,6 +125,7 @@ def _recv_from_teacher(self):
125
125
req .wait ()
126
126
127
127
def _get_distill_kwargs (self ):
128
+ """Return a dict of detached clones of the values in the student receive buffer, for use in student training."""
128
129
return {k : v .clone ().detach () for k , v in self .student_recv_buffer .items ()}
129
130
130
131
def _send_to_student (self , teacher_outputs ):
@@ -141,25 +142,6 @@ def _send_to_student(self, teacher_outputs):
141
142
for req in reqs :
142
143
req .wait ()
143
144
144
- # def _validate_ar(self, steps=3, osl=20, num_samples=20):
145
- # if self.rank != self.args.student_rank:
146
- # return
147
- # # Load MT-Bench prompts from HuggingFace
148
- # ds = load_dataset("HuggingFaceH4/mt_bench_prompts")["train"]
149
- # self.model.eval()
150
- # self.model.to(self.args.student_device)
151
- # ars = validate_ar(
152
- # self.model, self.tokenizer, ds, steps, osl, num_samples, self.args.student_device
153
- # )
154
- # # Print results
155
- # avg_ar = sum(ars) / len(ars)
156
- # print("\n==== AR Validation Results on MT-Bench ====")
157
- # print(f"Number of samples: {len(ars)}")
158
- # print(f"Output Sequence Length: {osl}")
159
- # print(f"Steps: {steps}")
160
- # print(f"Average AR: {avg_ar:.4f}")
161
- # self.model.train()
162
-
163
145
def train (self , dataloader ):
164
146
"""Main training entrance of the composed model."""
165
147
self ._reset_all_mem_stats ()
@@ -174,19 +156,24 @@ def train(self, dataloader):
174
156
project = os .environ ["WANDB_PROJECT" ],
175
157
config = {"epochs" : EPOCHS , "lr" : self .args .lr , "batch_size" : self .args .batch_size },
176
158
) as run :
177
- self .model , self .optimizer = self .load_student_model ()
159
+ self .model , self .optimizer , self . scheduler = self .load_student_model ()
178
160
self ._init_student_recv_buffer ()
179
161
wandb .watch (self .model , log = "all" )
180
162
181
163
for epoch in range (EPOCHS ):
182
- pbar = tqdm (dataloader )
164
+ pbar = (
165
+ tqdm (dataloader ) if self .rank == self .args .student_ranks [0 ] else dataloader
166
+ )
183
167
for i , batch in enumerate (pbar ):
184
168
global_step = epoch * len (dataloader ) + i
185
169
inputs = {k : v .to (self .model .device ) for k , v in batch .items ()}
186
170
self ._recv_from_teacher ()
187
171
loss , train_acc = self .student_step (inputs , ** self ._get_distill_kwargs ())
188
- pbar .set_description (f"Epoch { epoch } Loss:{ loss } Acc:{ train_acc } " )
189
172
173
+ if self .rank != self .args .student_ranks [0 ]:
174
+ continue
175
+
176
+ pbar .set_description (f"Epoch { epoch } Loss:{ loss } Acc:{ train_acc } " )
190
177
if global_step % LOG_INTERVAL == 0 :
191
178
run .log (
192
179
{
@@ -195,14 +182,10 @@ def train(self, dataloader):
195
182
"train_acc_step1" : train_acc [1 ],
196
183
"train_acc_step2" : train_acc [2 ],
197
184
"train_acc_step3" : train_acc [3 ],
185
+ "lr" : self .optimizer .param_groups [0 ]["lr" ],
198
186
},
199
187
step = global_step ,
200
188
)
201
-
202
- # This is not working for some reason.
203
- # if global_step > 0 and global_step % VALIDATE_INTERVAL == 0:
204
- # self._validate_ar()
205
-
206
189
if global_step > 0 and global_step % SAVE_INTERVAL == 0 :
207
190
self .save_pretrained (
208
191
f"{ self .args .out_path } /epoch_{ epoch } _step_{ global_step } "
0 commit comments