@@ -631,7 +631,6 @@ def train(
631631 ], # TODO: this is a WAR for sequence packing, we should fix this. Without this, backward will fail when TP is enabled.
632632 )
633633 seq_len = input_ids .shape [1 ]
634- attention_mask = None
635634 flash_attn_kwargs = get_flash_attention_kwargs (
636635 input_lengths = mb ["input_lengths" ],
637636 )
@@ -640,16 +639,13 @@ def train(
640639 input_ids = mb .get ("input_ids" ).cuda ()
641640 batch_size , seq_len = input_ids .shape
642641
643- attention_mask = torch .ones (
644- (batch_size , seq_len ),
645- dtype = torch .bool ,
646- device = input_ids .device ,
647- )
648642 position_ids = torch .arange (
649643 seq_len , device = input_ids .device
650644 ).repeat (batch_size , 1 )
651645 flash_attn_kwargs = {}
652646
647+ attention_mask = None
648+
653649 # add vlm kwargs to model call
654650 vlm_kwargs = mb .get_multimodal_dict (
655651 as_tensors = True , device = input_ids .device
@@ -952,7 +948,6 @@ def get_logprobs(
952948 return_attention_mask = False ,
953949 )
954950 seq_len = input_ids .shape [1 ]
955- attention_mask = None
956951 flash_attn_kwargs = get_flash_attention_kwargs (
957952 input_lengths = input_lengths ,
958953 )
@@ -972,15 +967,11 @@ def get_logprobs(
972967 ).repeat (batch_size , 1 )
973968 flash_attn_kwargs = {}
974969
975- # DTensor requires the casual attention kernel to hit,
976- # yet our attention mask above is not always all 1s
977- # this is fine because we mask with the actual attention mask
978- # later, but for input it has to be all 1s
979- attention_mask = torch .ones (
980- (batch_size , seq_len ),
981- dtype = torch .bool ,
982- device = input_ids .device ,
983- )
970+  # DTensor requires the causal attention kernel to hit,
971+ # yet our attention mask above is not always all 1s
972+ # this is fine because we mask with the actual attention mask
973+ # later, but for input it has to be all 1s
974+ attention_mask = None
984975
985976 # if there are multimodal kwargs, we don't need to add position_ids (computed internally)
986977 if len (vlm_kwargs ) > 0 :
0 commit comments