File tree Expand file tree Collapse file tree 5 files changed +7
-7
lines changed Expand file tree Collapse file tree 5 files changed +7
-7
lines changed Original file line number Diff line number Diff line change @@ -325,7 +325,7 @@ def forward(
325
325
326
326
if self .config ._attn_implementation == "flash_attention_2" :
327
327
# Flash Attention 2: Use cu_seqlens for variable length attention
328
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
328
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
329
329
attn_output , _ = attention_interface (
330
330
self ,
331
331
query_states ,
Original file line number Diff line number Diff line change @@ -592,7 +592,7 @@ def forward(
592
592
query_states = query_states .transpose (0 , 1 ).unsqueeze (0 )
593
593
key_states = key_states .transpose (0 , 1 ).unsqueeze (0 )
594
594
value_states = value_states .transpose (0 , 1 ).unsqueeze (0 )
595
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
595
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
596
596
597
597
attention_interface : Callable = eager_attention_forward
598
598
if self .config ._attn_implementation != "eager" :
@@ -927,7 +927,7 @@ def forward(
927
927
928
928
if self .config ._attn_implementation == "flash_attention_2" :
929
929
# Flash Attention 2: Use cu_seqlens for variable length attention
930
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
930
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
931
931
attn_output , _ = attention_interface (
932
932
self ,
933
933
query_states ,
Original file line number Diff line number Diff line change @@ -1619,7 +1619,7 @@ def forward(
1619
1619
query_states = query_states .transpose (0 , 1 ).unsqueeze (0 )
1620
1620
key_states = key_states .transpose (0 , 1 ).unsqueeze (0 )
1621
1621
value_states = value_states .transpose (0 , 1 ).unsqueeze (0 )
1622
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
1622
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
1623
1623
1624
1624
attention_interface : Callable = eager_attention_forward
1625
1625
if self .config ._attn_implementation != "eager" :
@@ -1928,7 +1928,7 @@ def forward(
1928
1928
1929
1929
if self .config ._attn_implementation == "flash_attention_2" :
1930
1930
# Flash Attention 2: Use cu_seqlens for variable length attention
1931
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
1931
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
1932
1932
attn_output , _ = attention_interface (
1933
1933
self ,
1934
1934
query_states ,
Original file line number Diff line number Diff line change @@ -245,7 +245,7 @@ def forward(
245
245
246
246
if self .config ._attn_implementation == "flash_attention_2" :
247
247
# Flash Attention 2: Use cu_seqlens for variable length attention
248
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
248
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
249
249
attn_output , _ = attention_interface (
250
250
self ,
251
251
query_states ,
Original file line number Diff line number Diff line change @@ -363,7 +363,7 @@ def forward(
363
363
364
364
if self .config ._attn_implementation == "flash_attention_2" :
365
365
# Flash Attention 2: Use cu_seqlens for variable length attention
366
- max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max (). item ()
366
+ max_seqlen = (cu_seqlens [1 :] - cu_seqlens [:- 1 ]).max ()
367
367
attn_output , _ = attention_interface (
368
368
self ,
369
369
query_states ,
You can’t perform that action at this time.
0 commit comments