Skip to content

Commit d1ce6b9

Browse files
committed
Reduce memory footprint for ChunkedDistributedLogProb
Signed-off-by: mloh <mloh@nvidia.com>
1 parent 3bdb852 commit d1ce6b9

File tree

1 file changed: +12 additions, -7 deletions

nemo_rl/distributed/model_utils.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def backward(
220220
seq_size = int(vocab_parallel_logits.shape[1])
221221
num_chunks = (seq_size + chunk_size - 1) // chunk_size
222222

223-
all_grad_input = []
223+
grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits)
224224

225225
for chunk_idx in range(num_chunks):
226226
chunk_start = chunk_idx * chunk_size
@@ -243,13 +243,18 @@ def backward(
243243
num_classes=partition_vocab_size,
244244
)
245245

246-
grad_input = is_chosen.float().sub_(softmax_output)
247-
248-
grad_input.mul_(grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1))
249-
250-
all_grad_input.append(grad_input)
246+
# Inplace index into the preallocated grad_input tensor
247+
grad_input_chunk = grad_input[:, chunk_start:chunk_end, :]
248+
249+
grad_input_chunk.copy_(
250+
is_chosen.float().sub_(softmax_output)
251+
) # inplace copy
252+
grad_input_chunk.mul_(
253+
grad_output[:, chunk_start:chunk_end].unsqueeze(dim=-1)
254+
)
251255

252-
grad_input = torch.cat(all_grad_input, dim=1)
256+
# Explicitly free before next iteration allocates
257+
del softmax_output, is_chosen, logits
253258

254259
# if you add an argument to the forward method, then you must add a corresponding None here
255260
return grad_input, None, None, None, None, None, None

0 commit comments

Comments (0)