Skip to content

Commit a2e7a2a

Browse files
committed
key is required, but none is accepted
Signed-off-by: Maanu Grover <maanug@nvidia.com>
1 parent 268442c commit a2e7a2a

File tree

1 file changed

+6
-6
lines changed
  • src/megatron/bridge/data/datasets

1 file changed

+6
-6
lines changed

src/megatron/bridge/data/datasets/sft.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,8 @@ def collate_fn(self, batch):
696696
if not self.get_attention_mask_from_fusion:
697697
attention_mask = [self._create_attention_mask(max_length) for _ in batch]
698698
attention_mask = torch.stack(attention_mask)
699+
else:
700+
attention_mask = None
699701
position_ids = [list(range(max_length)) for _ in batch]
700702
position_ids = torch.LongTensor(position_ids)
701703
input_ids = torch.LongTensor(
@@ -716,11 +718,9 @@ def collate_fn(self, batch):
716718
"answers": answers,
717719
"metadata": metadata,
718720
"token_count": token_count,
721+
"attention_mask": attention_mask,
719722
}
720723

721-
if not self.get_attention_mask_from_fusion:
722-
processed_batch["attention_mask"] = attention_mask
723-
724724
return processed_batch
725725

726726

@@ -1068,6 +1068,8 @@ def collate_fn(self, batch):
10681068
if not self.get_attention_mask_from_fusion:
10691069
attention_mask = [self._create_attention_mask(max_length) for _ in batch]
10701070
attention_mask = torch.stack(attention_mask)
1071+
else:
1072+
attention_mask = None
10711073
position_ids = [list(range(max_length)) for _ in batch]
10721074
position_ids = torch.LongTensor(position_ids)
10731075
input_ids = torch.LongTensor(
@@ -1088,9 +1090,7 @@ def collate_fn(self, batch):
10881090
"context_lengths": context_lengths,
10891091
"answers": answers,
10901092
"metadata": metadata,
1093+
"attention_mask": attention_mask,
10911094
}
10921095

1093-
if not self.get_attention_mask_from_fusion:
1094-
processed_batch["attention_mask"] = attention_mask
1095-
10961096
return processed_batch

0 commit comments

Comments
 (0)