Skip to content

Commit bfc8f67

Browse files
committed
Set attention_mask to None by default.
Signed-off-by: Jonas Yang <joyang@nvidia.com>
1 parent f521459 commit bfc8f67

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,6 @@ def train(
631631
], # TODO: this is a WAR for sequence packing, we should fix this. Without this, backward will fail when TP is enabled.
632632
)
633633
seq_len = input_ids.shape[1]
634-
attention_mask = None
635634
flash_attn_kwargs = get_flash_attention_kwargs(
636635
input_lengths=mb["input_lengths"],
637636
)
@@ -640,16 +639,13 @@ def train(
640639
input_ids = mb.get("input_ids").cuda()
641640
batch_size, seq_len = input_ids.shape
642641

643-
attention_mask = torch.ones(
644-
(batch_size, seq_len),
645-
dtype=torch.bool,
646-
device=input_ids.device,
647-
)
648642
position_ids = torch.arange(
649643
seq_len, device=input_ids.device
650644
).repeat(batch_size, 1)
651645
flash_attn_kwargs = {}
652646

647+
attention_mask = None
648+
653649
# add vlm kwargs to model call
654650
vlm_kwargs = mb.get_multimodal_dict(
655651
as_tensors=True, device=input_ids.device
@@ -952,7 +948,6 @@ def get_logprobs(
952948
return_attention_mask=False,
953949
)
954950
seq_len = input_ids.shape[1]
955-
attention_mask = None
956951
flash_attn_kwargs = get_flash_attention_kwargs(
957952
input_lengths=input_lengths,
958953
)
@@ -972,15 +967,11 @@ def get_logprobs(
972967
).repeat(batch_size, 1)
973968
flash_attn_kwargs = {}
974969

975-
# DTensor requires the causal attention kernel to hit,
976-
# yet our attention mask above is not always all 1s
977-
# this is fine because we mask with the actual attention mask
978-
# later, but for input it has to be all 1s
979-
attention_mask = torch.ones(
980-
(batch_size, seq_len),
981-
dtype=torch.bool,
982-
device=input_ids.device,
983-
)
970+
# DTensor requires the causal attention kernel to hit,
971+
# yet our attention mask above is not always all 1s
972+
# this is fine because we mask with the actual attention mask
973+
# later, but for input it has to be all 1s
974+
attention_mask = None
984975

985976
# if there are multimodal kwargs, we don't need to add position_ids (computed internally)
986977
if len(vlm_kwargs) > 0:

0 commit comments

Comments
 (0)