import torch
import torch.distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
from tqdm import tqdm
+from transformers import AutoModelForCausalLM
+from transformers.optimization import get_linear_schedule_with_warmup

import modelopt.torch.opt as mto
+import modelopt.torch.speculative as mtsp

mto.enable_huggingface_checkpointing()

# Hyperparameters for profiling
-EPOCHS = 10
+EPOCHS = 1
LOG_INTERVAL = 100
SAVE_INTERVAL = 20000
+MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+DRAFT_VOCAB_SIZE = 32000
# VALIDATE_INTERVAL = 20

-# We define the distill signal from teacher as the map of variable name to its shape and dtype.
+# Shape and dtype description of the distillation signal
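+# e.g. {"base_model_logits": (logits.shape, logits.dtype)}, one entry per tensor the teacher sends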
DistillMetadata = dict[str, tuple[torch.Size, torch.dtype]]

@@ -208,3 +214,172 @@ def train(self, dataloader):
        dist.barrier()
        # clean up processes
        dist.destroy_process_group()
+
+
+class EagleTPTrainer(BaseDistillTrainer):
+    @property
+    def current_rank_device(self):
+        if self.rank in self.args.student_ranks:
+            return self.args.student_devices[self.rank]
+        else:
+            return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]
+
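+    # The teacher is sharded with Hugging Face tensor parallelism over the teacher ranks' device mesh.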
+    def load_teacher_model(self):
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_PATH,
+            torch_dtype="auto",
+            tp_plan="auto",
+            device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
+        )
+        self.args.eagle_config["eagle_architecture_config"].update(
+            {
+                "hidden_size": model.config.hidden_size,
+                "vocab_size": model.config.vocab_size,
+                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+            }
+        )
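+        # Convert in place to an EAGLE speculative-decoding model (attaches the eagle_module draft head).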
+        mtsp.convert(model, [("eagle", self.args.eagle_config)])
+        model.eval()
+        self._print_model_placement(model)
+        return model
+
+    def load_student_model(self):
+        """Load student model on a single device and keep needed modules from teacher."""
+        # Load to CPU first to avoid OOM
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_PATH, torch_dtype="auto", device_map="cpu"
+        )
+        # Hidden size and vocab size must match the base model
+        self.args.eagle_config["eagle_architecture_config"].update(
+            {
+                "hidden_size": model.config.hidden_size,
+                "vocab_size": model.config.vocab_size,
+                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+            }
+        )
+        mtsp.convert(
+            model,
+            [("eagle", self.args.eagle_config)],
+        )
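+        # A draft vocab smaller than the base vocab needs a cached draft-to-target token mapping (d2t).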
+        if model.config.vocab_size > DRAFT_VOCAB_SIZE:
+            model_name = os.path.basename(os.path.normpath(MODEL_PATH))
+            vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
+            try:
+                vocab_cache = torch.load(vocab_cache_path)
+                assert len(vocab_cache) == DRAFT_VOCAB_SIZE
+                model.eagle_module.d2t = vocab_cache
+                print(f"Loaded draft vocab cache from {vocab_cache_path}.")
+            except Exception as e:
+                raise e
+
+        # TODO: copy needed modules and delete the rest
+        model.model._modules.pop("layers")
+        model.to(self.current_rank_device)
+
+        model.train()
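+        # DDP over the student process group so gradients sync across student ranks only.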
+        model = torch.nn.parallel.DistributedDataParallel(
+            model,
+            device_ids=[self.current_rank_device],
+            process_group=self.args.student_pgroup,
+            find_unused_parameters=True,
+        )
+        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
+        scheduler = get_linear_schedule_with_warmup(
+            optimizer, num_warmup_steps=0, num_training_steps=117380
+        )
+        self._print_model_placement(model)
+        return model, optimizer, scheduler
+
+    def teacher_step(self, model, inputs):
+        base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
+            **inputs,
+            freeze_base_model=True,
+            past_key_values=None,
+        )
+        # aux_hidden_states could be on multiple devices. Gather them and cat.
+        aux_hidden_states = torch.cat(
+            [t.to(base_model_logits.device) for t in model.pop_aux_hidden_states()], dim=-1
+        )
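+        # Split the teacher outputs along the batch dimension, one shard per student rank.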
+        base_model_hidden_states = base_model_hidden_states.chunk(len(self.args.student_ranks))
+        base_model_logits = base_model_logits.chunk(len(self.args.student_ranks))
+        aux_hidden_states = aux_hidden_states.chunk(len(self.args.student_ranks))
+
+        return [
+            {
+                "base_model_hidden_states": base_model_hidden_states[i],
+                "aux_hidden_states": aux_hidden_states[i],
+                "base_model_logits": base_model_logits[i],
+            }
+            for i in range(len(self.args.student_ranks))
+        ]
+
+    def student_step(
+        self,
+        inputs,
+        base_model_hidden_states,
+        aux_hidden_states,
+        base_model_logits,
+    ):
+        self.optimizer.zero_grad()
+        # Second stage forward using the unified model
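+        # Each student rank takes its own shard of the batch, matching the teacher's chunking.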
+        inputs = {k: v.chunk(len(self.args.student_ranks))[self.rank] for k, v in inputs.items()}
+        output = self.model(
+            **inputs,
+            # Provide base model outputs to bypass the base model forward.
+            base_model_outputs={
+                "base_model_hidden_states": base_model_hidden_states,
+                "aux_hidden_states": aux_hidden_states.clone().detach(),
+                "base_model_logits": base_model_logits.clone().detach(),
+            },
+        )
+        loss = output.loss
+        # print(f"Rank {self.rank} loss: {loss.item()}")
+        train_acc = output.train_acc
+
+        # Backward
+        loss.backward()
+        self.optimizer.step()
+        self.scheduler.step()
+        return round(loss.item(), 3), train_acc
+
+
+# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer):
+#     @property
+#     def current_rank_devices(self):
+#         if self.rank == self.args.student_rank:
+#             return [self.args.student_device]
+#         else:
+#             return self.args.teacher_devices
+
+#     def load_teacher_model(self):
+#         model = AutoModelForCausalLM.from_pretrained(
+#             MODEL_PATH,
+#             torch_dtype="auto",
+#             device_map="sequential",
+#             max_memory=dict.fromkeys(
+#                 self.args.teacher_devices, "999GiB"
+#             ),  # To use only given devices
+#         )
+#         self.args.eagle_config["eagle_architecture_config"].update(
+#             {
+#                 "hidden_size": model.config.hidden_size,
+#                 "vocab_size": model.config.vocab_size,
+#                 "draft_vocab_size": DRAFT_VOCAB_SIZE,
+#             }
+#         )
+#         mtsp.convert(model, [("eagle", self.args.eagle_config)])
+
+#         if model.config.vocab_size > DRAFT_VOCAB_SIZE:
+#             model_name = os.path.basename(os.path.normpath(MODEL_PATH))
+#             vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
+#             try:
+#                 vocab_cache = torch.load(vocab_cache_path)
+#                 assert len(vocab_cache) == DRAFT_VOCAB_SIZE
+#                 model.eagle_module.d2t = vocab_cache
+#                 print(f"Loaded draft vocab cache from {vocab_cache_path}.")
+#             except Exception as e:
+#                 raise e
+
+#         model.eval()
+#         self._print_model_placement(model)
+#         return model