 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
 
 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
 EPOCHS = 1
 LOG_INTERVAL = 100
 SAVE_INTERVAL = 20000
-MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-DRAFT_VOCAB_SIZE = 32000
 # VALIDATE_INTERVAL = 20
 
 # Shape and dtype description of the distillation signal
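Note: the `DistillMetadata` type referenced below is defined in a collapsed part of this file and does not appear in the diff. Judging from how the new `distill_metadata` property is built further down, it is presumably a name-to-(shape, dtype) mapping describing each tensor the teacher sends; a minimal sketch of such an alias, offered only as an assumption:

import torch

# Assumed sketch: DistillMetadata maps a tensor name to the (shape, dtype) the
# student needs in order to preallocate a matching receive buffer.
DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]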
@@ -51,13 +51,21 @@ class BaseDistillTrainer:
         student_step: student step function.
     """
 
-    def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
+    def __init__(self, rank, args, tokenizer):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
         args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
-        self.distill_metadata = distill_metadata
+        if rank in args.student_ranks:
+            self.model = self.prepare_student_model()
+            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
+            self.scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=0, num_training_steps=117380
+            )
+        else:
+            self.model = self.prepare_teacher_model()
+        self._print_model_placement(self.model)
 
     def _print_model_placement(self, module):
         for name, param in module.named_parameters():
@@ -67,6 +75,10 @@ def _print_model_placement(self, module):
     def current_rank_device(self):
         pass
 
+    @property
+    def distill_metadata(self):
+        pass
+
     def _reset_all_mem_stats(self):
         torch.cuda.reset_max_memory_allocated(self.current_rank_device)
@@ -162,7 +174,6 @@ def train(self, dataloader):
                 project=os.environ["WANDB_PROJECT"],
                 config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
             ) as run:
-                self.model, self.optimizer, self.scheduler = self.load_student_model()
                 self._init_student_recv_buffer()
                 wandb.watch(self.model, log="all")
@@ -198,7 +209,6 @@ def train(self, dataloader):
                     )
 
         else:
-            self.model = self.load_teacher_model()
             # Inference Loop
             for epoch in range(EPOCHS):
                 for i, batch in enumerate(dataloader):
@@ -217,16 +227,60 @@ def train(self, dataloader):
 
 
 class EagleTPTrainer(BaseDistillTrainer):
+    def __init__(self, rank, args, tokenizer):
+        args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
+        if args.eagle_config_path:
+            with open(args.eagle_config_path) as f:
+                custom_config = json.load(f)
+            args.eagle_config["eagle_architecture_config"].update(custom_config)
+
+        super().__init__(rank, args, tokenizer)
+
     @property
     def current_rank_device(self):
         if self.rank in self.args.student_ranks:
             return self.args.student_devices[self.rank]
         else:
             return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]
 
-    def load_teacher_model(self):
+    @property
+    def distill_metadata(self) -> DistillMetadata:
+        return {
+            "base_model_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        2048,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "aux_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        2048 * 3,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "base_model_logits": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.draft_vocab_size,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+        }
+
+    def prepare_teacher_model(self):
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH,
+            self.args.model_path,
             torch_dtype="auto",
             tp_plan="auto",
             device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
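The `distill_metadata` property added above describes every tensor a teacher rank streams to a student rank: the base model hidden states, the concatenated auxiliary hidden states, and the base model logits, all in bfloat16 and sized per student batch shard. Below is a minimal sketch of how a receive path such as `_init_student_recv_buffer` (not shown in this diff) could consume that description; the helper names and the use of blocking `torch.distributed` point-to-point ops are assumptions, not the file's actual implementation:

import torch
import torch.distributed as dist


def init_student_recv_buffer(distill_metadata, device):
    # One preallocated tensor per entry described by distill_metadata.
    return {
        name: torch.empty(shape, dtype=dtype, device=device)
        for name, (shape, dtype) in distill_metadata.items()
    }


def recv_distill_signal(recv_buffer, src_rank):
    # Blocking receive of each teacher tensor into its matching buffer.
    for tensor in recv_buffer.values():
        dist.recv(tensor, src=src_rank)
    return recv_buffer

A teacher rank would then send the same tensors, in the same key order, with `dist.send`.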
@@ -235,42 +289,33 @@ def load_teacher_model(self):
             {
                 "hidden_size": model.config.hidden_size,
                 "vocab_size": model.config.vocab_size,
-                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+                "draft_vocab_size": model.config.vocab_size,
             }
         )
+        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(model, [("eagle", self.args.eagle_config)])
         model.eval()
-        self._print_model_placement(model)
         return model
 
-    def load_student_model(self):
+    def prepare_student_model(self):
         """Load student model on a single device and keep needed modules from teacher."""
         # Load to CPU first to avoid OOM
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH, torch_dtype="auto", device_map="cpu"
+            self.args.model_path, torch_dtype="auto", device_map="cpu"
         )
         # Hidden size and vocab size must match base model
         self.args.eagle_config["eagle_architecture_config"].update(
             {
                 "hidden_size": model.config.hidden_size,
                 "vocab_size": model.config.vocab_size,
-                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+                "draft_vocab_size": model.config.vocab_size,
             }
         )
+        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(
             model,
             [("eagle", self.args.eagle_config)],
         )
-        if model.config.vocab_size > DRAFT_VOCAB_SIZE:
-            model_name = os.path.basename(os.path.normpath(MODEL_PATH))
-            vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
-            try:
-                vocab_cache = torch.load(vocab_cache_path)
-                assert len(vocab_cache) == DRAFT_VOCAB_SIZE
-                model.eagle_module.d2t = vocab_cache
-                print(f"Loaded draft vocab cache from {vocab_cache_path}.")
-            except Exception as e:
-                raise e
 
         # TODO:copy needed modules and del the rest
         model.model._modules.pop("layers")
@@ -283,12 +328,7 @@ def load_student_model(self):
             process_group=self.args.student_pgroup,
             find_unused_parameters=True,
         )
-        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=0, num_training_steps=117380
-        )
-        self._print_model_placement(model)
-        return model, optimizer, scheduler
+        return model
 
     def teacher_step(self, model, inputs):
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
@@ -341,45 +381,3 @@ def student_step(
         self.optimizer.step()
         self.scheduler.step()
         return round(loss.item(), 3), train_acc
-
-
-# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer):
-#     @property
-#     def current_rank_devices(self):
-#         if self.rank == self.args.student_rank:
-#             return [self.args.student_device]
-#         else:
-#             return self.args.teacher_devices
-
-#     def load_teacher_model(self):
-#         model = AutoModelForCausalLM.from_pretrained(
-#             MODEL_PATH,
-#             torch_dtype="auto",
-#             device_map="sequential",
-#             max_memory=dict.fromkeys(
-#                 self.args.teacher_devices, "999GiB"
-#             ),  # To use only given devices
-#         )
-#         self.args.eagle_config["eagle_architecture_config"].update(
-#             {
-#                 "hidden_size": model.config.hidden_size,
-#                 "vocab_size": model.config.vocab_size,
-#                 "draft_vocab_size": DRAFT_VOCAB_SIZE,
-#             }
-#         )
-#         mtsp.convert(model, [("eagle", self.args.eagle_config)])
-
-#         if model.config.vocab_size > DRAFT_VOCAB_SIZE:
-#             model_name = os.path.basename(os.path.normpath(MODEL_PATH))
-#             vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
-#             try:
-#                 vocab_cache = torch.load(vocab_cache_path)
-#                 assert len(vocab_cache) == DRAFT_VOCAB_SIZE
-#                 model.eagle_module.d2t = vocab_cache
-#                 print(f"Loaded draft vocab cache from {vocab_cache_path}.")
-#             except Exception as e:
-#                 raise e
-
-#         model.eval()
-#         self._print_model_placement(model)
-#         return model
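One usage note on the new `EagleTPTrainer.__init__`: it starts from `EAGLE3_DEFAULT_CFG["config"]` and, when `args.eagle_config_path` is set, overlays that JSON file onto `eagle_architecture_config`. A small sketch of the mechanism; the file name and the `draft_hidden_layers` key are hypothetical, and keys such as `hidden_size`, `vocab_size`, and `draft_vocab_size` are overwritten from the base model config inside `prepare_teacher_model` / `prepare_student_model` regardless of what the overlay contains:

import json

# Hypothetical overlay: only the keys present in the JSON replace the matching
# eagle_architecture_config defaults; "draft_hidden_layers" is a placeholder key.
with open("eagle_overlay.json", "w") as f:
    json.dump({"draft_hidden_layers": 1}, f)

# The trainer would then be launched with args.eagle_config_path = "eagle_overlay.json".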