Skip to content

Commit aaa5b0b

Browse files
committed
Reduce memory footprint for ChunkedDistributedLogProb
1 parent 3bdb852 commit aaa5b0b

File tree

1 file changed: +8 additions, -7 deletions

nemo_rl/distributed/model_utils.py

Lines changed: 8 additions & 7 deletions
Original line number | New line number | Change
@@ -220,7 +220,7 @@ def backward(
220220
seq_size = int(vocab_parallel_logits.shape[1])
221221
num_chunks = (seq_size + chunk_size - 1) // chunk_size
222222

223-
all_grad_input = []
223+
grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits)
224224

225225
for chunk_idx in range(num_chunks):
226226
chunk_start = chunk_idx * chunk_size
@@ -243,13 +243,14 @@ def backward(
243243
num_classes=partition_vocab_size,
244244
)
245245

246-
grad_input = is_chosen.float().sub_(softmax_output)
246+
# Inplace index into the preallocated grad_input tensor
247+
grad_input_chunk = grad_input[:, chunk_start:chunk_end, :]
248+
249+
grad_input_chunk.copy_(is_chosen.float().sub_(softmax_output)) # inplace copy
250+
grad_input_chunk.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
247251

248-
grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
249-
250-
all_grad_input.append(grad_input)
251-
252-
grad_input = torch.cat(all_grad_input, dim=1)
252+
# Explicitly free before next iteration allocates
253+
del softmax_output, is_chosen, logits
253254

254255
# if you add an argument to the forward method, then you must add a corresponding None here
255256
return grad_input, None, None, None, None, None, None

Comments (0)