We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 50689b5 + 0e994e9 commit bc14324Copy full SHA for bc14324
nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -916,15 +916,6 @@ def train(
916
self.model.train()
917
918
with ctx:
919
- # dim 1 is always assumed to be the sequence dim, sanity check this here
920
- sequence_dim = 1
921
- seq_dim_size = data["input_ids"].shape[sequence_dim]
922
- for k, v in data.items():
923
- if torch.is_tensor(v) and len(v.shape) > 1:
924
- assert v.shape[sequence_dim] == seq_dim_size, (
925
- f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}"
926
- )
927
-
928
all_mb_metrics = []
929
losses = []
930
total_num_microbatches = 0
0 commit comments