
Commit 0e70be4

h-guo18 committed: add student model ddp
Signed-off-by: h-guo18 <[email protected]>
1 parent b70b418 · commit 0e70be4

File tree

2 files changed (+132 -101 lines)


examples/speculative_decoding/distill_trainer.py

Lines changed: 23 additions & 20 deletions
@@ -26,8 +26,8 @@
 mto.enable_huggingface_checkpointing()

 # Hyperparameters for profiling
-EPOCHS = 20
-LOG_INTERVAL = 25
+EPOCHS = 1
+LOG_INTERVAL = 1
 SAVE_INTERVAL = 20000
 # VALIDATE_INTERVAL = 20

@@ -48,6 +48,7 @@ class BaseDistillTrainer:
     def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
+        args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
         self.distill_metadata = distill_metadata
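
Note: the added `args.student_pgroup` mirrors the existing teacher group, giving the DDP-wrapped student its own communicator so student collectives (gradient all-reduce, barriers) never involve teacher ranks. A minimal sketch of the partitioning, assuming a hypothetical 4-rank job with ranks 0 and 1 as teacher and 2 and 3 as student:

    import torch.distributed as dist

    # Hypothetical rank split; the real one comes from the launcher arguments.
    teacher_ranks = [0, 1]
    student_ranks = [2, 3]

    # new_group() must be called by every rank with identical arguments,
    # even by ranks that are not members of the group being created.
    teacher_pgroup = dist.new_group(ranks=teacher_ranks)
    student_pgroup = dist.new_group(ranks=student_ranks)

    # Collectives can then be scoped to one side, e.g.:
    # dist.barrier(group=student_pgroup)  # only student ranks synchronize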
@@ -57,17 +58,15 @@ def _print_model_placement(self, module):
             print(f"(Rank {self.rank}) {name} ---> {param.device} ")

     @property
-    def current_rank_devices(self):
+    def current_rank_device(self):
         pass

     def _reset_all_mem_stats(self):
-        for d in self.current_rank_devices:
-            torch.cuda.reset_max_memory_allocated(d)
+        torch.cuda.reset_max_memory_allocated(self.current_rank_device)

     def _print_mem_stats(self):
-        for d in self.current_rank_devices:
-            max_mem = torch.cuda.max_memory_allocated(d)
-            print(f"GPU {d}: Max memory allocated: {max_mem / 1024**3:.2f} GB")
+        max_mem = torch.cuda.max_memory_allocated(self.current_rank_device)
+        print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB")

     @abstractmethod
     def load_teacher_model(self):
@@ -86,7 +85,7 @@ def student_step(self, *args, **kwargs):
         pass

     def save_pretrained(self, path=None):
-        if self.rank == self.args.student_rank:
+        if self.rank == self.args.student_ranks[0]:
             path = self.args.out_path if path is None else path
             self.model.save_pretrained(path)
             self.tokenizer.save_pretrained(path)
@@ -96,24 +95,24 @@ def _check_valid_message(self, message: dict[str, torch.Tensor]):
         # Check if keys and length match between message and distill_metadata
         if set(message.keys()) != set(self.distill_metadata.keys()):
             raise ValueError(
-                f"Message keys from teacher: {set(message.keys())} \n"
+                f"Message keys: {set(message.keys())} \n"
                 f"do not match expected keys {set(self.distill_metadata.keys())}"
             )
         if len(message) != len(self.distill_metadata):
             raise ValueError(
-                f"Message length from teacher: {len(message)} \n"
+                f"Message length: {len(message)} \n"
                 f"does not match expected {len(self.distill_metadata)}"
             )
         for k, v in message.items():
             if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]:
                 raise ValueError(
-                    f"Invalid message from teacher. {k} has shape {v.shape} and dtype {v.dtype}, \n"
+                    f"Invalid message. {k} has shape {v.shape} and dtype {v.dtype}, \n"
                     f"expected {self.distill_metadata[k]}"
                 )

     def _init_student_recv_buffer(self):
         self.student_recv_buffer = {
-            k: torch.empty(v[0], device=self.args.student_device, dtype=v[1])
+            k: torch.empty(v[0], device=self.current_rank_device, dtype=v[1])
             for k, v in self.distill_metadata.items()
         }

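Note: `distill_metadata` maps each tensor name to an expected `(shape, dtype)` pair, which is why `_check_valid_message` indexes `v[0]` for shape and `v[1]` for dtype, and why the receive buffers can be allocated once up front. A minimal sketch of a metadata entry and its buffer, with hypothetical names and sizes:

    import torch

    # Hypothetical metadata: name -> (shape, dtype), matching the v[0]/v[1] use above.
    distill_metadata = {
        "hidden_states": (torch.Size([4, 512, 4096]), torch.bfloat16),
        "logits": (torch.Size([4, 512, 32000]), torch.bfloat16),
    }

    # Buffers are allocated once and reused for every receive.
    recv_buffer = {
        k: torch.empty(shape, device="cuda:0", dtype=dtype)  # per-rank device in practice
        for k, (shape, dtype) in distill_metadata.items()
    }
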
@@ -131,12 +130,16 @@ def _get_distill_kwargs(self):
     def _send_to_student(self, teacher_outputs):
         if self.rank != self.args.teacher_ranks[0]:
             return
-        self._check_valid_message(teacher_outputs)
-        reqs = [
-            dist.isend(buffer, dst=self.args.student_rank) for buffer in teacher_outputs.values()
-        ]
-        for req in reqs:
-            req.wait()
+        # TODO: use broadcast
+        assert len(teacher_outputs) == len(self.args.student_ranks), (
+            f"Number of teacher outputs {len(teacher_outputs)} does not \
+                match number of student ranks {len(self.args.student_ranks)}"
+        )
+        for s in self.args.student_ranks:
+            self._check_valid_message(teacher_outputs[s])
+            reqs = [dist.isend(buffer, dst=s) for buffer in teacher_outputs[s].values()]
+            for req in reqs:
+                req.wait()

     # def _validate_ar(self, steps=3, osl=20, num_samples=20):
     #     if self.rank != self.args.student_rank:
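
Note: this hunk shows only the teacher-side send, now looping over every student rank instead of a single `student_rank`. A matching per-rank receive on the student side, sketched here under the hypothetical helper name `_recv_from_teacher` (the actual method is not part of this hunk), would post one async receive per preallocated buffer:

    def _recv_from_teacher(self):
        # Hypothetical counterpart to _send_to_student: each student rank posts
        # one async receive per buffer, sourced from the lead teacher rank.
        reqs = [
            dist.irecv(buf, src=self.args.teacher_ranks[0])
            for buf in self.student_recv_buffer.values()
        ]
        for req in reqs:
            req.wait()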
@@ -161,7 +164,7 @@ def train(self, dataloader):
         """Main training entrance of the composed model."""
         self._reset_all_mem_stats()

-        if self.rank == self.args.student_rank:
+        if self.rank in self.args.student_ranks:
             import wandb

             wandb.login()
