Skip to content

Commit 35cc9a8

Browse files
committed
polish
Signed-off-by: h-guo18 <[email protected]>
1 parent 5daa239 commit 35cc9a8

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

examples/speculative_decoding/distill_trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@
4949

5050
class BaseDistillTrainer:
5151
"""
52-
Base class for distillation trainer. Initalized and called on every rank.
52+
Base distill trainer with basic training loop and overlapped teacher and student steps.
53+
Initalized and called on every rank.
5354
Args:
5455
rank: rank of the current process
5556
args: arguments
@@ -252,6 +253,8 @@ def train(self):
252253

253254

254255
class EagleTPTrainer(BaseDistillTrainer):
256+
"""A subclass of BaseDistillTrainer for online eagle training, with base model TP and student DDP."""
257+
255258
def __init__(self, rank, args, tokenizer, dataloader):
256259
# Load eagle config
257260
args.eagle_config = EAGLE3_DEFAULT_CFG["config"]

examples/speculative_decoding/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def main():
9292
# TODO: add sanity check for args
9393

9494
def set_ranks(args):
95-
# TODO(hg): add "no-parallel", "MP", "FSDP".
95+
# TODO(hg): This is for TP-DDP setting only. Add "no-parallel", "MP", "FSDP".
9696
args.world_size = len(args.teacher_devices) + len(args.student_devices)
9797
args.student_ranks = list(range(len(args.student_devices)))
9898
args.teacher_ranks = list(

0 commit comments

Comments
 (0)