We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent cfa7aa3, commit cab076f (Copy full SHA for cab076f)
nemo_rl/distributed/model_utils.py
@@ -220,7 +220,7 @@ def backward(
220
seq_size = int(vocab_parallel_logits.shape[1])
221
num_chunks = (seq_size + chunk_size - 1) // chunk_size
222
223
- grad_input: torch.Tensor = torch.empty_like(
+ grad_input: torch.Tensor = torch.zeros_like(
224
vocab_parallel_logits, dtype=torch.float32
225
)
226
@@ -334,7 +334,7 @@ def backward(
334
B, S, V_local = vocab_parallel_logits.shape
335
num_chunks = (int(S) + chunk_size - 1) // chunk_size
336
337
338
339
340
0 commit comments