Commit cab076f

Initialize gradient as zero instead of empty

Signed-off-by: mloh <mloh@nvidia.com>
1 parent cfa7aa3 commit cab076f

File tree

1 file changed: +2 −2 lines changed

nemo_rl/distributed/model_utils.py

Lines changed: 2 additions & 2 deletions

@@ -220,7 +220,7 @@ def backward(
         seq_size = int(vocab_parallel_logits.shape[1])
         num_chunks = (seq_size + chunk_size - 1) // chunk_size

-        grad_input: torch.Tensor = torch.empty_like(
+        grad_input: torch.Tensor = torch.zeros_like(
             vocab_parallel_logits, dtype=torch.float32
         )

@@ -334,7 +334,7 @@ def backward(
         B, S, V_local = vocab_parallel_logits.shape
         num_chunks = (int(S) + chunk_size - 1) // chunk_size

-        grad_input: torch.Tensor = torch.empty_like(
+        grad_input: torch.Tensor = torch.zeros_like(
             vocab_parallel_logits, dtype=torch.float32
         )
