We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f10e4e4 commit 7725656Copy full SHA for 7725656
nemo_rl/distributed/model_utils.py
@@ -220,7 +220,7 @@ def backward(
220
seq_size = int(vocab_parallel_logits.shape[1])
221
num_chunks = (seq_size + chunk_size - 1) // chunk_size
222
223
- grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits)
+ grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits, dtype=torch.float32)
224
225
for chunk_idx in range(num_chunks):
226
chunk_start = chunk_idx * chunk_size
0 commit comments