We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent f10e4e4 commit ac29781Copy full SHA for ac29781
nemo_rl/distributed/model_utils.py
@@ -220,7 +220,9 @@ def backward(
220
seq_size = int(vocab_parallel_logits.shape[1])
221
num_chunks = (seq_size + chunk_size - 1) // chunk_size
222
223
- grad_input: torch.Tensor = torch.empty_like(vocab_parallel_logits)
+ grad_input: torch.Tensor = torch.empty_like(
224
+ vocab_parallel_logits, dtype=torch.float32
225
+ )
226
227
for chunk_idx in range(num_chunks):
228
chunk_start = chunk_idx * chunk_size
0 commit comments