
Commit 5ae4479

refactor
Signed-off-by: h-guo18 <[email protected]>
1 parent 6c62207 commit 5ae4479

File tree

2 files changed: +101 −137 lines changed


examples/speculative_decoding/distill_trainer.py

Lines changed: 69 additions & 71 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import os
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -26,15 +27,14 @@
 
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
+from modelopt.torch.speculative.config import EAGLE3_DEFAULT_CFG
 
 mto.enable_huggingface_checkpointing()
 
 # Hyperparameters for profiling
 EPOCHS = 1
 LOG_INTERVAL = 100
 SAVE_INTERVAL = 20000
-MODEL_PATH = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-DRAFT_VOCAB_SIZE = 32000
 # VALIDATE_INTERVAL = 20
 
 # Shape and dtype description of the distillation signal
@@ -51,13 +51,21 @@ class BaseDistillTrainer:
         student_step: student step function.
     """
 
-    def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
+    def __init__(self, rank, args, tokenizer):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
         args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
-        self.distill_metadata = distill_metadata
+        if rank in args.student_ranks:
+            self.model = self.prepare_student_model()
+            self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
+            self.scheduler = get_linear_schedule_with_warmup(
+                self.optimizer, num_warmup_steps=0, num_training_steps=117380
+            )
+        else:
+            self.model = self.prepare_teacher_model()
+        self._print_model_placement(self.model)
 
     def _print_model_placement(self, module):
         for name, param in module.named_parameters():
@@ -67,6 +75,10 @@ def _print_model_placement(self, module):
     def current_rank_device(self):
         pass
 
+    @property
+    def distill_metadata(self):
+        pass
+
     def _reset_all_mem_stats(self):
         torch.cuda.reset_max_memory_allocated(self.current_rank_device)
 
@@ -162,7 +174,6 @@ def train(self, dataloader):
                 project=os.environ["WANDB_PROJECT"],
                 config={"epochs": EPOCHS, "lr": self.args.lr, "batch_size": self.args.batch_size},
             ) as run:
-                self.model, self.optimizer, self.scheduler = self.load_student_model()
                 self._init_student_recv_buffer()
                 wandb.watch(self.model, log="all")
 
@@ -198,7 +209,6 @@ def train(self, dataloader):
                        )
 
        else:
-            self.model = self.load_teacher_model()
            # Inference Loop
            for epoch in range(EPOCHS):
                for i, batch in enumerate(dataloader):
@@ -217,16 +227,60 @@
 
 
 class EagleTPTrainer(BaseDistillTrainer):
+    def __init__(self, rank, args, tokenizer):
+        args.eagle_config = EAGLE3_DEFAULT_CFG["config"]
+        if args.eagle_config_path:
+            with open(args.eagle_config_path) as f:
+                custom_config = json.load(f)
+            args.eagle_config["eagle_architecture_config"].update(custom_config)
+
+        super().__init__(rank, args, tokenizer)
+
     @property
     def current_rank_device(self):
         if self.rank in self.args.student_ranks:
             return self.args.student_devices[self.rank]
         else:
             return self.args.teacher_devices[self.rank - len(self.args.student_ranks)]
 
-    def load_teacher_model(self):
+    @property
+    def distill_metadata(self) -> DistillMetadata:
+        return {
+            "base_model_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        2048,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "aux_hidden_states": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        2048 * 3,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+            "base_model_logits": (
+                torch.Size(
+                    [
+                        int(self.args.batch_size / len(self.args.student_ranks)),
+                        self.args.training_seq_len,
+                        self.args.draft_vocab_size,
+                    ]
+                ),
+                torch.bfloat16,
+            ),
+        }
+
+    def prepare_teacher_model(self):
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH,
+            self.args.model_path,
             torch_dtype="auto",
             tp_plan="auto",
             device_mesh=DeviceMesh.from_group(self.args.teacher_pgroup, "cuda"),
@@ -235,42 +289,33 @@ def load_teacher_model(self):
             {
                 "hidden_size": model.config.hidden_size,
                 "vocab_size": model.config.vocab_size,
-                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+                "draft_vocab_size": model.config.vocab_size,
             }
         )
+        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(model, [("eagle", self.args.eagle_config)])
         model.eval()
-        self._print_model_placement(model)
         return model
 
-    def load_student_model(self):
+    def prepare_student_model(self):
         """Load student model on a single device and keep needed modules from teacher."""
         # Load to CPU first to avoid OOM
         model = AutoModelForCausalLM.from_pretrained(
-            MODEL_PATH, torch_dtype="auto", device_map="cpu"
+            self.args.model_path, torch_dtype="auto", device_map="cpu"
         )
         # Hidden size and vocab size must match base model
         self.args.eagle_config["eagle_architecture_config"].update(
             {
                 "hidden_size": model.config.hidden_size,
                 "vocab_size": model.config.vocab_size,
-                "draft_vocab_size": DRAFT_VOCAB_SIZE,
+                "draft_vocab_size": model.config.vocab_size,
             }
         )
+        self.args.draft_vocab_size = model.config.vocab_size
         mtsp.convert(
             model,
             [("eagle", self.args.eagle_config)],
         )
-        if model.config.vocab_size > DRAFT_VOCAB_SIZE:
-            model_name = os.path.basename(os.path.normpath(MODEL_PATH))
-            vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
-            try:
-                vocab_cache = torch.load(vocab_cache_path)
-                assert len(vocab_cache) == DRAFT_VOCAB_SIZE
-                model.eagle_module.d2t = vocab_cache
-                print(f"Loaded draft vocab cache from {vocab_cache_path}.")
-            except Exception as e:
-                raise e
 
         # TODO:copy needed modules and del the rest
         model.model._modules.pop("layers")
@@ -283,12 +328,7 @@ def load_student_model(self):
             process_group=self.args.student_pgroup,
             find_unused_parameters=True,
         )
-        optimizer = torch.optim.AdamW(model.parameters(), lr=self.args.lr)
-        scheduler = get_linear_schedule_with_warmup(
-            optimizer, num_warmup_steps=0, num_training_steps=117380
-        )
-        self._print_model_placement(model)
-        return model, optimizer, scheduler
+        return model
 
     def teacher_step(self, model, inputs):
         base_model_hidden_states, base_model_logits, _, _ = model._base_model_forward(
@@ -341,45 +381,3 @@ def student_step(
         self.optimizer.step()
         self.scheduler.step()
         return round(loss.item(), 3), train_acc
-
-
-# class EagleMPTrainer(EagleTPTrainer, BaseDistillTrainer):
-#     @property
-#     def current_rank_devices(self):
-#         if self.rank == self.args.student_rank:
-#             return [self.args.student_device]
-#         else:
-#             return self.args.teacher_devices
-
-#     def load_teacher_model(self):
-#         model = AutoModelForCausalLM.from_pretrained(
-#             MODEL_PATH,
-#             torch_dtype="auto",
-#             device_map="sequential",
-#             max_memory=dict.fromkeys(
-#                 self.args.teacher_devices, "999GiB"
-#             ),  # To use only given devices
-#         )
-#         self.args.eagle_config["eagle_architecture_config"].update(
-#             {
-#                 "hidden_size": model.config.hidden_size,
-#                 "vocab_size": model.config.vocab_size,
-#                 "draft_vocab_size": DRAFT_VOCAB_SIZE,
-#             }
-#         )
-#         mtsp.convert(model, [("eagle", self.args.eagle_config)])
-
-#         if model.config.vocab_size > DRAFT_VOCAB_SIZE:
-#             model_name = os.path.basename(os.path.normpath(MODEL_PATH))
-#             vocab_cache_path = os.path.join("draft_vocab_cache", model_name, "d2t.pt")
-#             try:
-#                 vocab_cache = torch.load(vocab_cache_path)
-#                 assert len(vocab_cache) == DRAFT_VOCAB_SIZE
-#                 model.eagle_module.d2t = vocab_cache
-#                 print(f"Loaded draft vocab cache from {vocab_cache_path}.")
-#             except Exception as e:
-#                 raise e
-
-#         model.eval()
-#         self._print_model_placement(model)
-#         return model
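
For context, a minimal sketch of how the refactored trainer might now be driven on each rank: model, optimizer, and scheduler construction move into the constructor, so callers no longer invoke load_student_model/load_teacher_model themselves. The argument fields mirror those read in the diff (model_path, student_ranks, teacher_ranks, batch_size, training_seq_len, lr, eagle_config_path); the argparse wiring and build_dataloader below are assumptions for illustration, not part of this commit.

# Hypothetical per-rank driver sketch, assuming the refactored API in this commit.
import torch.distributed as dist
from transformers import AutoTokenizer


def main(rank, args):
    # World size spans all student and teacher ranks referenced by the trainer.
    dist.init_process_group(
        "nccl", rank=rank, world_size=len(args.student_ranks) + len(args.teacher_ranks)
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)

    # The constructor now prepares the student or teacher model for this rank
    # (plus optimizer/scheduler on student ranks); no separate load_* call.
    trainer = EagleTPTrainer(rank, args, tokenizer)

    dataloader = build_dataloader(tokenizer, args)  # placeholder helper, not from this repo
    trainer.train(dataloader)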

0 commit comments
