
Commit 9e676e6

cyyever and qubvel authored
[qwen] remove unnecessary CUDA sync in qwen2_5_vl (#39870)
Signed-off-by: cyy <[email protected]>
Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent 392be3b commit 9e676e6

File tree

5 files changed: +7 additions, −7 deletions

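The change is the same one-liner in each file: the `.item()` call on the per-segment sequence-length maximum is dropped, so `max_seqlen` stays a 0-dim CUDA tensor. `.item()` copies the scalar back to the host and blocks the CPU until the GPU has produced the value, i.e. an implicit CUDA synchronization, which is the sync the commit title refers to. A minimal sketch of the difference, assuming a hypothetical `cu_seqlens` tensor on a CUDA device (not taken from the diff):

import torch

if torch.cuda.is_available():
    # Hypothetical cumulative sequence lengths, as used for variable-length attention.
    cu_seqlens = torch.tensor([0, 128, 384, 512], device="cuda", dtype=torch.int32)

    # Before: .item() performs a blocking device-to-host copy, so the CPU stalls
    # until the GPU has finished computing the max (an implicit CUDA sync).
    max_seqlen_int = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # Python int

    # After: the max stays on-device as a 0-dim tensor; no host copy is needed,
    # so the kernel launch queue remains asynchronous.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()  # torch.Tensor

    print(max_seqlen_int, max_seqlen)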

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 1 addition & 1 deletion

@@ -325,7 +325,7 @@ def forward(

         if self.config._attn_implementation == "flash_attention_2":
             # Flash Attention 2: Use cu_seqlens for variable length attention
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
                 query_states,

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 2 additions & 2 deletions

@@ -592,7 +592,7 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":

@@ -927,7 +927,7 @@ def forward(

         if self.config._attn_implementation == "flash_attention_2":
             # Flash Attention 2: Use cu_seqlens for variable length attention
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
                 query_states,

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 2 additions & 2 deletions

@@ -1619,7 +1619,7 @@ def forward(
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":

@@ -1928,7 +1928,7 @@ def forward(

         if self.config._attn_implementation == "flash_attention_2":
             # Flash Attention 2: Use cu_seqlens for variable length attention
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
                 query_states,

src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 1 deletion

@@ -245,7 +245,7 @@ def forward(

         if self.config._attn_implementation == "flash_attention_2":
             # Flash Attention 2: Use cu_seqlens for variable length attention
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
                 query_states,

src/transformers/models/qwen2_vl/modeling_qwen2_vl.py

Lines changed: 1 addition & 1 deletion

@@ -363,7 +363,7 @@ def forward(

         if self.config._attn_implementation == "flash_attention_2":
             # Flash Attention 2: Use cu_seqlens for variable length attention
-            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
             attn_output, _ = attention_interface(
                 self,
                 query_states,
