@@ -696,6 +696,8 @@ def collate_fn(self, batch):
         if not self.get_attention_mask_from_fusion:
             attention_mask = [self._create_attention_mask(max_length) for _ in batch]
             attention_mask = torch.stack(attention_mask)
+        else:
+            attention_mask = None
         position_ids = [list(range(max_length)) for _ in batch]
         position_ids = torch.LongTensor(position_ids)
         input_ids = torch.LongTensor(
@@ -716,11 +718,9 @@ def collate_fn(self, batch):
716718 "answers" : answers ,
717719 "metadata" : metadata ,
718720 "token_count" : token_count ,
721+ "attention_mask" : attention_mask ,
719722 }
720723
721- if not self .get_attention_mask_from_fusion :
722- processed_batch ["attention_mask" ] = attention_mask
723-
724724 return processed_batch
725725
726726
@@ -1068,6 +1068,8 @@ def collate_fn(self, batch):
         if not self.get_attention_mask_from_fusion:
             attention_mask = [self._create_attention_mask(max_length) for _ in batch]
             attention_mask = torch.stack(attention_mask)
+        else:
+            attention_mask = None
         position_ids = [list(range(max_length)) for _ in batch]
         position_ids = torch.LongTensor(position_ids)
         input_ids = torch.LongTensor(
@@ -1088,9 +1090,7 @@ def collate_fn(self, batch):
10881090 "context_lengths" : context_lengths ,
10891091 "answers" : answers ,
10901092 "metadata" : metadata ,
1093+ "attention_mask" : attention_mask ,
10911094 }
10921095
1093- if not self .get_attention_mask_from_fusion :
1094- processed_batch ["attention_mask" ] = attention_mask
1095-
10961096 return processed_batch
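
Both hunk pairs apply the same refactor to the two `collate_fn` implementations: instead of conditionally adding `"attention_mask"` to `processed_batch` after the dict is built, the mask is computed (or set to `None`) up front and the key is always emitted, so downstream code can rely on its presence regardless of the `get_attention_mask_from_fusion` setting. A minimal self-contained sketch of the resulting pattern follows; `_CollateSketch`, its constructor, the causal-mask construction, and the trimmed-down return dict are assumptions for illustration, while `get_attention_mask_from_fusion` and `_create_attention_mask` are the names the diff actually touches.

import torch

class _CollateSketch:
    # Hypothetical stand-in for the dataset class; only the masking
    # control flow mirrors the diff.
    def __init__(self, get_attention_mask_from_fusion: bool, max_length: int):
        self.get_attention_mask_from_fusion = get_attention_mask_from_fusion
        self.max_length = max_length

    def _create_attention_mask(self, max_length: int) -> torch.Tensor:
        # Assumed mask construction: a causal (lower-triangular) mask,
        # True where attention is blocked.
        return torch.tril(torch.ones((max_length, max_length))).unsqueeze(0) < 0.5

    def collate_fn(self, batch):
        max_length = self.max_length
        if not self.get_attention_mask_from_fusion:
            attention_mask = [self._create_attention_mask(max_length) for _ in batch]
            attention_mask = torch.stack(attention_mask)
        else:
            # Fused kernels build the mask themselves, but the key is
            # still emitted so the batch schema stays stable.
            attention_mask = None
        position_ids = torch.LongTensor([list(range(max_length)) for _ in batch])
        return {
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

# With fusion enabled the key is present but None; otherwise it is a
# stacked bool tensor of shape (batch, 1, max_length, max_length).
batch = _CollateSketch(get_attention_mask_from_fusion=True, max_length=4).collate_fn([0, 1])
assert batch["attention_mask"] is None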