Commit ac29781

Explicitly set the dtype for gradient tensor
Signed-off-by: mloh <mloh@nvidia.com>
1 parent f10e4e4 commit ac29781

1 file changed: +3 -1 lines changed

nemo_rl/distributed/model_utils.py

Lines changed: 3 additions & 1 deletion
@@ -220,7 +220,9 @@ def backward(
         seq_size = int(vocab_parallel_logits.shape[1])
         num_chunks = (seq_size + chunk_size - 1) // chunk_size

-        grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits)
+        grad_input: torch.Tensor = torch.empty_like(
+            vocab_parallel_logits, dtype=torch.float32
+        )

         for chunk_idx in range(num_chunks):
             chunk_start = chunk_idx * chunk_size
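
The commit message does not say why the dtype is pinned. The behavior it relies on is that `torch.empty_like` inherits the source tensor's dtype unless `dtype=` is passed. A minimal sketch of that difference follows. The `logits` tensor is a hypothetical stand-in for `vocab_parallel_logits`, and its bfloat16 dtype and shape are assumptions chosen for illustration, not taken from the repository.

```python
import torch

# Hypothetical stand-in for vocab_parallel_logits. The bfloat16 dtype and
# the (batch, seq, vocab) shape are assumptions for illustration only.
logits = torch.zeros(2, 4, 8, dtype=torch.bfloat16)

# Without dtype=, empty_like inherits the source tensor's dtype.
inherited = torch.empty_like(logits)

# With dtype=, only the element type changes; shape and device still
# follow the source tensor.
explicit = torch.empty_like(logits, dtype=torch.float32)

print(inherited.dtype, explicit.dtype, tuple(explicit.shape))
```

Under these assumptions, the float32 buffer uses four bytes per element instead of two, so the explicit dtype trades memory for a gradient buffer whose precision does not depend on the dtype of the incoming logits.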
