from torch.autograd import Function
from torch.distributed import ProcessGroup
from torch.nn import CrossEntropyLoss
+ from torch.nn.functional import log_softmax

from colossalai.shardformer.layer._operation import reduce_forward
from colossalai.shardformer.shard import ShardConfig

from .utils import is_share_sp_tp

- __all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+ __all__ = [
+     "DistCrossEntropy",
+     "cross_entropy_1d",
+     "dist_cross_entropy",
+     "DistLogProb",
+     "dist_log_prob_1d",
+     "dist_log_prob",
+ ]

_IGNORE_IDX = -100
@@ -137,6 +145,98 @@ def backward(ctx, grad_output):
        return grad_logits, None, None, None, None, None, None


+ class DistLogProb(Function):
+     r"""
+     Overwrite the forward and backward functions to compute the log prob of each target token
+     on the vocab-sharded logits, before any gather across tensor-parallel ranks.
+
+     Args:
+         Function (:class:`torch.autograd.Function`): default
+     """
+
+     @staticmethod
+     def forward(
+         ctx,
+         vocab_logits: torch.Tensor,
+         target: torch.Tensor,
+         process_group: ProcessGroup,
+         vocab_size: int,
+         dtype=torch.float32,
+     ):
+         ##################
+         # Step 1: find the global maximum value of the logits
+         ##################
+         logits_max = torch.max(vocab_logits, dim=-1)[0]
+         handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
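+         # The max is subtracted from the logits in Step 3 to keep exp() numerically stable;
+         # async_op=True lets this all-reduce overlap with the local masking work in Step 2.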
+
+         ##################
+         # Step 2: compute the local mask, which is used to select log_prob values in Step 4.
+         # For acceleration, this step overlaps with the all-reduce issued in Step 1.
+         ##################
+         rank = dist.get_rank(group=process_group)
+         world_size = dist.get_world_size(group=process_group)
+         if vocab_size is None:
+             partition_vocab_size = vocab_logits.size()[-1]
+             global_vocab_size = partition_vocab_size * world_size
+         else:
+             global_vocab_size = vocab_size
+             partition_vocab_size = global_vocab_size // world_size
+         # lower and upper thresholds of the vocab ids owned by this rank
+         delta = (global_vocab_size + world_size - 1) // world_size
+         down_threshold = rank * delta
+         up_threshold = down_threshold + delta
+         if up_threshold > global_vocab_size:
+             up_threshold = global_vocab_size
+         # mask out targets that fall outside this rank's vocab shard
+         mask = (target < down_threshold) | (target >= up_threshold)
+         masked_target = target.clone() - down_threshold
+         masked_target[mask] = 0
+         masked_target_1d = masked_target.view(-1).contiguous()
+         handle.wait()
+
+         ##################
+         # Step 3: compute the global sum of exp(logits)
+         ##################
+         vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+         exp_logits = torch.exp(vocab_logits)
+         sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)  # local sum of exp(logits)
+         dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
+
+         ##################
+         # Step 4: compute the local log prob. First compute log_softmax, then select log probs via the local mask.
+         ##################
+         log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1))  # log_softmax
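+         # log_softmax identity: log_softmax(x)_i = (x_i - max(x)) - log(sum_j exp(x_j - max(x)));
+         # vocab_logits already has the global max subtracted, and sum_exp_logits has been
+         # all-reduced over the full vocab, so this yields the correct global log_softmax per shard.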
+         log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
+         log_probs[mask.unsqueeze(-1)] = 0  # set masked values to zero
+         dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
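+         # Each target id is owned by exactly one rank's vocab shard; all other ranks contribute
+         # zeros at that position, so the SUM all-reduce above reconstructs the full log prob
+         # for every token without gathering the vocab dimension.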
+
+         ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
+         ctx.dtype = dtype
+         return log_probs
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
+         ##################
+         # Step 1: recover the softmax values (local shard of the global softmax)
+         ##################
+         softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
+
+         ##################
+         # Step 2: update the softmax values at the local target indices
+         ##################
+         partition_vocab_size = softmax_logits.shape[-1]
+         softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
+         update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+         softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
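+         # Together with the negation in Step 3, this implements
+         # d(log_softmax(x)[target]) / dx = one_hot(target) - softmax(x) on the local shard;
+         # masked positions (targets owned by another rank) skip the one_hot term.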
+
+         ##################
+         # Step 3: compute grad_logits, the gradient of the loss with respect to the input logits
+         ##################
+         grad_logits = -softmax_logits.mul_(grad_output)
+         return grad_logits, None, None, None, None, None, None
+
+
def cross_entropy_1d(
    vocab_logits: torch.Tensor,
    labels: torch.Tensor,
@@ -149,6 +249,16 @@ def cross_entropy_1d(
    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


+ def dist_log_prob_1d(
+     vocab_logits: torch.Tensor,
+     labels: torch.Tensor,
+     process_group: ProcessGroup = None,
+     vocab_size: int = None,
+     dtype: torch.dtype = None,
+ ) -> torch.Tensor:
+     return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
+
+
def dist_cross_entropy(
    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
    logits: torch.Tensor,  # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
    loss, num_nonzero = loss[0], loss[1].detach()
    loss = (loss / num_nonzero).squeeze()
    return loss
+
+
+ def dist_log_prob(
+     labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
+     logits: torch.Tensor,  # [B, S, Vocab_size]
+     shard_config: ShardConfig,
+     vocab_size: int,
+     dtype: torch.dtype,
+     seq_dim: int = 1,
+ ) -> torch.Tensor:
+     """
+     Helper to compute the log prob for most shardformer models supporting PP and TP.
+     """
+     # Split labels if the output is not gathered
+     parallel_output = shard_config.parallel_output
+     is_tp = shard_config.enable_tensor_parallelism
+
+     # TODO: support sequence parallelism (SP)
+     labels = labels[..., 1:]
+     logits = logits[..., :-1, :]
+     labels = labels.contiguous()
+     logits = logits.contiguous()
+     assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+     # Flatten the tokens
+     if is_tp and parallel_output:
+         log_prob = dist_log_prob_1d(
+             logits,
+             labels,
+             process_group=shard_config.tensor_parallel_process_group,
+             vocab_size=vocab_size,
+             dtype=dtype,
+         )
+     else:
+         log_prob = log_softmax(logits, dim=-1)
+         log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
+
+     return log_prob
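
A minimal usage sketch of the new helper, assuming tensor parallelism is enabled in the ShardConfig and the logits are vocab-sharded per rank (the variable names below are illustrative):

    # labels: [B, S] token ids; logits: [B, S, V // tp_size] when shard_config.parallel_output is True
    log_prob = dist_log_prob(
        labels,
        logits,
        shard_config,
        vocab_size=full_vocab_size,  # full, unsharded vocab size (illustrative name)
        dtype=torch.float32,
    )
    # log_prob: [B, S - 1, 1], the log prob of each next token under the model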