@@ -26,8 +26,8 @@
 mto.enable_huggingface_checkpointing()

 # Hyperparameters for profiling
-EPOCHS = 20
-LOG_INTERVAL = 25
+EPOCHS = 1
+LOG_INTERVAL = 1
 SAVE_INTERVAL = 20000
 # VALIDATE_INTERVAL = 20

@@ -48,6 +48,7 @@ class BaseDistillTrainer:
     def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
+        args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
         self.distill_metadata = distill_metadata
@@ -57,17 +58,15 @@ def _print_model_placement(self, module):
             print(f"(Rank {self.rank}) {name} ---> {param.device}")

     @property
-    def current_rank_devices(self):
+    def current_rank_device(self):
         pass

     def _reset_all_mem_stats(self):
-        for d in self.current_rank_devices:
-            torch.cuda.reset_max_memory_allocated(d)
+        torch.cuda.reset_max_memory_allocated(self.current_rank_device)

     def _print_mem_stats(self):
-        for d in self.current_rank_devices:
-            max_mem = torch.cuda.max_memory_allocated(d)
-            print(f"GPU {d}: Max memory allocated: {max_mem / 1024**3:.2f} GB")
+        max_mem = torch.cuda.max_memory_allocated(self.current_rank_device)
+        print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB")

     @abstractmethod
     def load_teacher_model(self):
@@ -86,7 +85,7 @@ def student_step(self, *args, **kwargs):
         pass

     def save_pretrained(self, path=None):
-        if self.rank == self.args.student_rank:
+        if self.rank == self.args.student_ranks[0]:
             path = self.args.out_path if path is None else path
             self.model.save_pretrained(path)
             self.tokenizer.save_pretrained(path)
@@ -96,24 +95,24 @@ def _check_valid_message(self, message: dict[str, torch.Tensor]):
         # Check if keys and length match between message and distill_metadata
         if set(message.keys()) != set(self.distill_metadata.keys()):
             raise ValueError(
-                f"Message keys from teacher: {set(message.keys())}\n"
+                f"Message keys: {set(message.keys())}\n"
                 f"do not match expected keys {set(self.distill_metadata.keys())}"
             )
         if len(message) != len(self.distill_metadata):
             raise ValueError(
-                f"Message length from teacher: {len(message)}\n"
+                f"Message length: {len(message)}\n"
                 f"does not match expected {len(self.distill_metadata)}"
             )
         for k, v in message.items():
             if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]:
                 raise ValueError(
-                    f"Invalid message from teacher. {k} has shape {v.shape} and dtype {v.dtype},\n"
+                    f"Invalid message. {k} has shape {v.shape} and dtype {v.dtype},\n"
                     f"expected {self.distill_metadata[k]}"
                 )

     def _init_student_recv_buffer(self):
         self.student_recv_buffer = {
-            k: torch.empty(v[0], device=self.args.student_device, dtype=v[1])
+            k: torch.empty(v[0], device=self.current_rank_device, dtype=v[1])
             for k, v in self.distill_metadata.items()
         }

@@ -131,12 +130,16 @@ def _get_distill_kwargs(self):
     def _send_to_student(self, teacher_outputs):
         if self.rank != self.args.teacher_ranks[0]:
             return
-        self._check_valid_message(teacher_outputs)
-        reqs = [
-            dist.isend(buffer, dst=self.args.student_rank) for buffer in teacher_outputs.values()
-        ]
-        for req in reqs:
-            req.wait()
+        # TODO: use broadcast
+        assert len(teacher_outputs) == len(self.args.student_ranks), (
+            f"Number of teacher outputs {len(teacher_outputs)} does not "
+            f"match number of student ranks {len(self.args.student_ranks)}"
+        )
+        for s in self.args.student_ranks:
+            self._check_valid_message(teacher_outputs[s])
+            reqs = [dist.isend(buffer, dst=s) for buffer in teacher_outputs[s].values()]
+            for req in reqs:
+                req.wait()

     # def _validate_ar(self, steps=3, osl=20, num_samples=20):
     #     if self.rank != self.args.student_rank:
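Aside: the student-side counterpart to _send_to_student is not shown in this diff. Presumably each student rank posts matching receives into the buffers allocated by _init_student_recv_buffer; a minimal sketch, assuming a hypothetical _recv_from_teacher method with blocking waits:

    # Hypothetical sketch, not part of this commit: each student rank receives
    # one tensor per distill_metadata entry, in the same key order in which the
    # teacher isends them, into the preallocated buffers.
    def _recv_from_teacher(self):
        reqs = [
            dist.irecv(buffer, src=self.args.teacher_ranks[0])
            for buffer in self.student_recv_buffer.values()
        ]
        for req in reqs:
            req.wait()
        # Clone so the buffers can be reused for the next message.
        return {k: buf.clone() for k, buf in self.student_recv_buffer.items()}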
@@ -161,7 +164,7 @@ def train(self, dataloader):
161164 """Main training entrance of the composed model."""
         self._reset_all_mem_stats()

-        if self.rank == self.args.student_rank:
+        if self.rank in self.args.student_ranks:
             import wandb

             wandb.login()
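For readers, the v[0]/v[1] indexing in _check_valid_message suggests that DistillMetadata maps each tensor name to a (shape, dtype) pair, and after this change teacher_outputs carries one such message per student rank. A small illustrative sketch (the tensor names, sizes, and ranks below are made up):

import torch

# Assumed structure: tensor name -> (shape, dtype), matching the
# v[0]/v[1] indexing in _check_valid_message.
distill_metadata = {
    "hidden_states": (torch.Size([8, 512, 4096]), torch.bfloat16),
    "logits": (torch.Size([8, 512, 32000]), torch.bfloat16),
}

# teacher_outputs is now keyed by student rank: one message dict per student,
# validated and isent per rank in _send_to_student.
student_ranks = [2, 3]  # example ranks
teacher_outputs = {
    s: {name: torch.empty(shape, dtype=dtype) for name, (shape, dtype) in distill_metadata.items()}
    for s in student_ranks
}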