Commit 7725656

Explicitly set the dtype for gradient tensor
Signed-off-by: mloh <mloh@nvidia.com>
1 parent f10e4e4 commit 7725656

1 file changed: +1 -1 lines changed

nemo_rl/distributed/model_utils.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def backward(
         seq_size = int(vocab_parallel_logits.shape[1])
         num_chunks = (seq_size + chunk_size - 1) // chunk_size
 
-        grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits)
+        grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits, dtype=torch.float32)
 
         for chunk_idx in range(num_chunks):
             chunk_start = chunk_idx * chunk_size
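
The commit message does not say why the dtype is pinned, so the following is only a sketch of one likely motivation. `torch.empty_like` inherits the dtype of its input unless `dtype` is passed, so a low-precision logits tensor would otherwise yield a low-precision gradient buffer, and float32 chunk gradients written into it would be rounded. The bfloat16 logits, the tensor shapes, and the names `logits`, `inherited`, `explicit`, and `chunk_grad` are assumptions made for illustration and do not come from `model_utils.py`.

```python
import torch

# Hypothetical stand-in for vocab_parallel_logits: [batch, seq, vocab] in bfloat16.
logits = torch.zeros(2, 8, 16, dtype=torch.bfloat16)

# Without dtype, the buffer inherits bfloat16 from the input tensor.
inherited = torch.empty_like(logits)
# With dtype, shape and device still follow the input, but storage is float32.
explicit = torch.empty_like(logits, dtype=torch.float32)

# Hypothetical float32 gradient for one chunk of the sequence dimension.
# 1 + 2**-10 is exactly representable in float32 but not in bfloat16.
chunk_grad = torch.full((2, 4, 16), 1.0 + 2.0 ** -10, dtype=torch.float32)

# Slice assignment casts to the destination dtype.
inherited[:, 0:4, :] = chunk_grad  # rounded to bfloat16 on write
explicit[:, 0:4, :] = chunk_grad   # stored without rounding

print(inherited.dtype, inherited[0, 0, 0].item())
print(explicit.dtype, explicit[0, 0, 0].item())
```

In this sketch the inherited buffer loses the small increment while the float32 buffer keeps it. The trade-off is that the float32 buffer uses twice the memory of a bfloat16 one.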
