@@ -26,8 +26,8 @@
 mto.enable_huggingface_checkpointing()

 # Hyperparameters for profiling
-EPOCHS = 20
-LOG_INTERVAL = 25
+EPOCHS = 1
+LOG_INTERVAL = 1
 SAVE_INTERVAL = 20000
 # VALIDATE_INTERVAL = 20
@@ -48,6 +48,7 @@ class BaseDistillTrainer:
     def __init__(self, rank, args, tokenizer, distill_metadata: DistillMetadata):
         self.rank = rank
         args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
+        args.student_pgroup = dist.new_group(ranks=args.student_ranks)
         self.args = args
         self.tokenizer = tokenizer
         self.distill_metadata = distill_metadata
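For orientation, here is a minimal sketch of how the two process groups might be wired up at launch. The rank split is hypothetical; the diff only establishes that args carries teacher_ranks and student_ranks and that each list now gets its own group:

import torch.distributed as dist

# Hypothetical split: all ranks but the last drive the teacher,
# the last rank drives the student.
world_size = dist.get_world_size()
args.teacher_ranks = list(range(world_size - 1))
args.student_ranks = [world_size - 1]

# dist.new_group() is collective: every rank must call it, including
# ranks that are not members of the group being created.
args.teacher_pgroup = dist.new_group(ranks=args.teacher_ranks)
args.student_pgroup = dist.new_group(ranks=args.student_ranks)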
@@ -57,17 +58,15 @@ def _print_model_placement(self, module):
             print(f"(Rank {self.rank}) {name} ---> {param.device}")

     @property
-    def current_rank_devices(self):
+    def current_rank_device(self):
         pass

     def _reset_all_mem_stats(self):
-        for d in self.current_rank_devices:
-            torch.cuda.reset_max_memory_allocated(d)
+        torch.cuda.reset_max_memory_allocated(self.current_rank_device)

     def _print_mem_stats(self):
-        for d in self.current_rank_devices:
-            max_mem = torch.cuda.max_memory_allocated(d)
-            print(f"GPU {d}: Max memory allocated: {max_mem / 1024**3:.2f} GB")
+        max_mem = torch.cuda.max_memory_allocated(self.current_rank_device)
+        print(f"GPU {self.current_rank_device}: Max memory allocated: {max_mem / 1024**3:.2f} GB")

     @abstractmethod
     def load_teacher_model(self):
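current_rank_device is left abstract here; a minimal sketch of the override a subclass presumably provides, assuming one process per GPU:

@property
def current_rank_device(self):
    # Assumption: ranks are mapped round-robin onto the visible GPUs.
    return torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")

Both torch.cuda.reset_max_memory_allocated() and torch.cuda.max_memory_allocated() accept a device argument, so the simplified single-device bookkeeping in this hunk works unchanged.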
@@ -86,7 +85,7 @@ def student_step(self, *args, **kwargs):
         pass

     def save_pretrained(self, path=None):
-        if self.rank == self.args.student_rank:
+        if self.rank == self.args.student_ranks[0]:
             path = self.args.out_path if path is None else path
             self.model.save_pretrained(path)
             self.tokenizer.save_pretrained(path)
@@ -96,24 +95,24 @@ def _check_valid_message(self, message: dict[str, torch.Tensor]):
         # Check if keys and length match between message and distill_metadata
         if set(message.keys()) != set(self.distill_metadata.keys()):
             raise ValueError(
-                f"Message keys from teacher: {set(message.keys())}\n"
+                f"Message keys: {set(message.keys())}\n"
                 f"do not match expected keys {set(self.distill_metadata.keys())}"
             )
         if len(message) != len(self.distill_metadata):
             raise ValueError(
-                f"Message length from teacher: {len(message)}\n"
+                f"Message length: {len(message)}\n"
                 f"does not match expected {len(self.distill_metadata)}"
             )
         for k, v in message.items():
             if v.shape != self.distill_metadata[k][0] or v.dtype != self.distill_metadata[k][1]:
                 raise ValueError(
-                    f"Invalid message from teacher. {k} has shape {v.shape} and dtype {v.dtype},\n"
+                    f"Invalid message. {k} has shape {v.shape} and dtype {v.dtype},\n"
                     f"expected {self.distill_metadata[k]}"
                 )

     def _init_student_recv_buffer(self):
         self.student_recv_buffer = {
-            k: torch.empty(v[0], device=self.args.student_device, dtype=v[1])
+            k: torch.empty(v[0], device=self.current_rank_device, dtype=v[1])
             for k, v in self.distill_metadata.items()
         }
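The validation logic implies that DistillMetadata maps each message key to an expected (shape, dtype) pair. A hypothetical instance (the keys and shapes below are illustrative, not from the source):

# DistillMetadata: message key -> (expected shape, expected dtype)
distill_metadata = {
    "hidden_states": (torch.Size([8, 512, 4096]), torch.bfloat16),
    "logits": (torch.Size([8, 512, 32000]), torch.bfloat16),
}

_init_student_recv_buffer then preallocates one torch.empty(shape, device=..., dtype=dtype) tensor per key, so incoming tensors can be received without per-step allocation.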
@@ -131,12 +130,16 @@ def _get_distill_kwargs(self):
     def _send_to_student(self, teacher_outputs):
         if self.rank != self.args.teacher_ranks[0]:
             return
-        self._check_valid_message(teacher_outputs)
-        reqs = [
-            dist.isend(buffer, dst=self.args.student_rank) for buffer in teacher_outputs.values()
-        ]
-        for req in reqs:
-            req.wait()
+        # TODO: use broadcast
+        assert len(teacher_outputs) == len(self.args.student_ranks), (
+            f"Number of teacher outputs {len(teacher_outputs)} does not \
+            match number of student ranks {len(self.args.student_ranks)}"
+        )
+        for s in self.args.student_ranks:
+            self._check_valid_message(teacher_outputs[s])
+            reqs = [dist.isend(buffer, dst=s) for buffer in teacher_outputs[s].values()]
+            for req in reqs:
+                req.wait()

     # def _validate_ar(self, steps=3, osl=20, num_samples=20):
     #     if self.rank != self.args.student_rank:
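The student-side counterpart of _send_to_student is not shown in this hunk; a sketch of what it presumably looks like, receiving into the preallocated buffers from the lead teacher rank:

def _recv_from_teacher(self):
    # Hypothetical receive loop matching _send_to_student: point-to-point
    # isend/irecv pairs match in posting order, so both sides must iterate
    # the buffers in the same (metadata) key order.
    reqs = [
        dist.irecv(buffer, src=self.args.teacher_ranks[0])
        for buffer in self.student_recv_buffer.values()
    ]
    for req in reqs:
        req.wait()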
@@ -161,7 +164,7 @@ def train(self, dataloader):
         """Main training entrance of the composed model."""
         self._reset_all_mem_stats()

-        if self.rank == self.args.student_rank:
+        if self.rank in self.args.student_ranks:
             import wandb

             wandb.login()