# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
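# Silence the HF tokenizers fork-parallelism warning when DataLoader workers fork the process.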

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG

mto.enable_huggingface_checkpointing()
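# Registers ModelOpt state with HuggingFace checkpointing so converted models
# survive save_pretrained/from_pretrained round trips.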

# Hyperparameters for profiling
EPOCHS = 1
LOG_INTERVAL = 100
SAVE_INTERVAL = 20000
-MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-DRAFT_VOCAB_SIZE = 32000
# VALIDATE_INTERVAL = 20

# Shape and dtype description of the distillation signal
@@ -51,13 +51,21 @@ class BaseDistillTrainer:
        student_step: student step function.
    """

-    def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
+    def __init__(self, rank, args, tokenizer):
        self.rank = rank
        args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
        args.student_pgroup = dist.new_group(ranks=args.student_ranks)
        self.args = args
        self.tokenizer = tokenizer
-        self.distill_metadata = distill_metadata
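+        # Student ranks build and train the EAGLE draft module; all other ranks host the frozen teacher.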
+        if rank in args.student_ranks:
+            self.model = self.prepare_student_model()
+            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
+            self.scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=0, num_training_steps=117380
+            )
+        else:
+            self.model = self.prepare_teacher_model()
+        self._print_model_placement(self.model)

    def _print_model_placement(self, module):
        for name, param in module.named_parameters():
@@ -67,6 +75,10 @@ def _print_model_placement(self, module):
    def current_rank_device(self):
        pass

+    @property
+    def distill_metadata(self):
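+        """Shapes and dtypes of the tensors each student rank receives from the teacher per step."""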
+        pass
+
    def _reset_all_mem_stats(self):
        torch.cuda.reset_max_memory_allocated(self.current_rank_device)

@@ -162,7 +174,6 @@ def train(self, dataloader):
                project=os.environ["WANDB_PROJECT"],
                config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
            ) as run:
-                self.model, self.optimizer, self.scheduler = self.load_student_model()
                self._init_student_recv_buffer()
                wandb.watch(self.model, log="all")

@@ -198,7 +209,6 @@ def train(self, dataloader):
                )

        else:
-            self.model = self.load_teacher_model()
            # Inference Loop
            for epoch in range(EPOCHS):
                for i, batch in enumerate(dataloader):
@@ -217,16 +227,60 @@ def train(self, dataloader):


class EagleTPTrainer(BaseDistillTrainer):
+    def __init__(self, rank, args, tokenizer):
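+        # Start from ModelOpt's EAGLE-3 default config, then overlay any user-supplied JSON overrides.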
+        args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
+        if args.eagle_config_path:
+            with open(args.eagle_config_path) as f:
+                custom_config = json.load(f)
+            args.eagle_config["eagle_architecture_config"].update(custom_config)
+
+        super().__init__(rank, args, tokenizer)
+
    @property
    def current_rank_device(self):
        if self.rank in self.args.student_ranks:
            return self.args.student_devices[self.rank]
        else:
            return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]

-    def load_teacher_model(self):
+    @property
+    def distill_metadata(self) -> DistillMetadata:
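+        # Per-rank receive-buffer spec, consumed by _init_student_recv_buffer(). The hardcoded
+        # 2048 assumes the base model's hidden size (e.g. TinyLlama-1.1B), and 2048 * 3 assumes
+        # EAGLE-3's three auxiliary hidden-state layers.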
+        return {
+            "base_model_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        2048,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "aux_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        2048 * 3,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "base_model_logits": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.draft_vocab_size,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+        }
+
+    def prepare_teacher_model(self):
        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH,
+            self.args.model_path,
            torch_dtype="auto",
            tp_plan="auto",
            device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
@@ -235,42 +289,33 @@ def load_teacher_model(self):
            {
                "hidden_size": model.config.hidden_size,
                "vocab_size": model.config.vocab_size,
-                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+                "draft_vocab_size": model.config.vocab_size,
            }
        )
+        self.args.draft_vocab_size = model.config.vocab_size
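+        # The draft head now reuses the full teacher vocabulary, so no d2t vocab mapping is needed.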
        mtsp.convert(model, [("eagle", self.args.eagle_config)])
        model.eval()
-        self._print_model_placement(model)
        return model

-    def load_student_model(self):
+    def prepare_student_model(self):
        """Load student model on a single device and keep needed modules from teacher."""
        # Load to CPU first to avoid OOM
        model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH, torch_dtype="auto", device_map="cpu"
+            self.args.model_path, torch_dtype="auto", device_map="cpu"
        )
        # Hidden size and vocab size must match base model
        self.args.eagle_config["eagle_architecture_config"].update(
            {
                "hidden_size": model.config.hidden_size,
                "vocab_size": model.config.vocab_size,
-                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+                "draft_vocab_size": model.config.vocab_size,
            }
        )
+        self.args.draft_vocab_size = model.config.vocab_size
        mtsp.convert(
            model,
            [("eagle", self.args.eagle_config)],
        )
-        if model.config.vocab_size > DRAFT_VOCAB_SIZE:
-            model_name = os.path.basename(os.path.normpath(MODEL_PATH))
-            vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
-            try:
-                vocab_cache = torch.load(vocab_cache_path)
-                assert len(vocab_cache) == DRAFT_VOCAB_SIZE
-                model.eagle_module.d2t = vocab_cache
-                print(f"Loaded draft vocab cache from {vocab_cache_path}.")
-            except Exception as e:
-                raise e

        # TODO: copy needed modules and del the rest
        model.model._modules.pop("layers")
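        # Dropping the base decoder stack frees memory: the student only trains the
        # eagle_module and reuses lightweight base modules such as the embeddings and lm_head.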
@@ -283,12 +328,7 @@ def load_student_model(self):
            process_group=self.args.student_pgroup,
            find_unused_parameters=True,
        )
-        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=0, num_training_steps=117380
-        )
-        self._print_model_placement(model)
-        return model, optimizer, scheduler
+        return model

    def teacher_step(self, model, inputs):
        base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
@@ -341,45 +381,3 @@ def student_step(
        self.optimizer.step()
        self.scheduler.step()
        return round(loss.item(), 3), train_acc
-
-
-# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer):
-#     @property
-#     def current_rank_devices(self):
-#         if self.rank == self.args.student_rank:
-#             return [self.args.student_device]
-#         else:
-#             return self.args.teacher_devices
-
-#     def load_teacher_model(self):
-#         model = AutoModelForCausalLM.from_pretrained(
-#             MODEL_PATH,
-#             torch_dtype="auto",
-#             device_map="sequential",
-#             max_memory=dict.fromkeys(
-#                 self.args.teacher_devices, "999GiB"
-#             ),  # To use only given devices
-#         )
-#         self.args.eagle_config["eagle_architecture_config"].update(
-#             {
-#                 "hidden_size": model.config.hidden_size,
-#                 "vocab_size": model.config.vocab_size,
-#                 "draft_vocab_size": DRAFT_VOCAB_SIZE,
-#             }
-#         )
-#         mtsp.convert(model, [("eagle", self.args.eagle_config)])
-
-#         if model.config.vocab_size > DRAFT_VOCAB_SIZE:
-#             model_name = os.path.basename(os.path.normpath(MODEL_PATH))
-#             vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
-#             try:
-#                 vocab_cache = torch.load(vocab_cache_path)
-#                 assert len(vocab_cache) == DRAFT_VOCAB_SIZE
-#                 model.eagle_module.d2t = vocab_cache
-#                 print(f"Loaded draft vocab cache from {vocab_cache_path}.")
-#             except Exception as e:
-#                 raise e
-
-#         model.eval()
-#         self._print_model_placement(model)
-#         return model