@@ -220,7 +220,7 @@ def backward(
220220 seq_size = int (vocab_parallel_logits .shape [1 ])
221221 num_chunks = (seq_size + chunk_size - 1 ) // chunk_size
222222
223- all_grad_input = []
223+ grad_input : torch . Tensor = torch . empty_like ( vocab_parallel_logits )
224224
225225 for chunk_idx in range (num_chunks ):
226226 chunk_start = chunk_idx * chunk_size
@@ -243,13 +243,14 @@ def backward(
243243 num_classes = partition_vocab_size ,
244244 )
245245
246- grad_input = is_chosen .float ().sub_ (softmax_output )
246+ # Inplace index into the preallocated grad_input tensor
247+ grad_input_chunk = grad_input [:, chunk_start :chunk_end , :]
248+
249+ grad_input_chunk .copy_ (is_chosen .float ().sub_ (softmax_output )) # inplace copy
250+ grad_input_chunk .mul_ (grad_output [:, chunk_start :chunk_end ].unsqueeze (dim = - 1 ))
247251
248- grad_input .mul_ (grad_output [:, chunk_start :chunk_end ].unsqueeze (dim = - 1 ))
249-
250- all_grad_input .append (grad_input )
251-
252- grad_input = torch .cat (all_grad_input , dim = 1 )
252+ # Explicitly free before next iteration allocates
253+ del softmax_output , is_chosen , logits
253254
254255 # if you add an argument to the forward method, then you must add a corresponding None here
255256 return grad_input , None , None , None , None , None , None
0 commit comments