
Commit 87d0879

Commit message: up
1 parent bc40971 commit 87d0879

2 files changed: +4 −4 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 2 deletions
@@ -540,7 +540,7 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     return torch.empty_like(query), query.new_empty(lse_shape)


-@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda")
+@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
 def _wrapped_flash_attn_3_hub(
     query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -549,7 +549,7 @@ def _wrapped_flash_attn_3_hub(
     return out, lse


-@_register_fake("flash_attn_3_hub_func")
+@_register_fake("vllm_flash_attn3::_flash_attn_forward")
 def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     batch_size, seq_len, num_heads, head_dim = query.shape
     lse_shape = (batch_size, seq_len, num_heads)
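
For context, _custom_op and _register_fake in attention_dispatch.py appear to be thin wrappers around torch.library.custom_op and torch.library.register_fake, which expect a qualified "namespace::op_name" operator name and expose the result as torch.ops.<namespace>.<op_name>; the commit accordingly replaces the bare "flash_attn_3_hub_func" string with the qualified "vllm_flash_attn3::_flash_attn_forward". Below is a minimal, self-contained sketch of the same registration pattern, assuming torch >= 2.4; the namespace, op name, and body are hypothetical and are not the diffusers implementation:

from typing import Tuple

import torch


# Hypothetical namespace and op name, for illustration only.
@torch.library.custom_op("demo_fa3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _demo_flash_attn_forward(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Stand-in body: a real implementation would call the FA3 kernel and return (out, lse).
    batch_size, seq_len, num_heads, _ = query.shape
    out = torch.empty_like(query)
    lse = query.new_empty((batch_size, seq_len, num_heads))
    return out, lse


@_demo_flash_attn_forward.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Fake (meta) implementation: only shapes and dtypes matter here, so torch.compile
    # can trace the op without launching the CUDA kernel.
    batch_size, seq_len, num_heads, _ = query.shape
    return torch.empty_like(query), query.new_empty((batch_size, seq_len, num_heads))

Once registered, the op is callable as torch.ops.demo_fa3._flash_attn_forward(q, k, v), and the fake implementation is what the compiler uses when tracing with meta tensors.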

src/diffusers/utils/kernels_utils.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ def _get_fa3_from_hub():
     from kernels import get_kernel

     try:
-        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
-        return vllm_flash_attn3
+        flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+        return flash_attn_3_hub
     except Exception as e:
         raise e
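
In kernels_utils.py the change is only a variable rename: the object returned by get_kernel is now bound as flash_attn_3_hub, presumably to avoid confusion with the vllm_flash_attn3 operator namespace introduced above. A rough usage sketch of the helper follows, assuming the Hugging Face kernels package is installed; the repo id and the attribute name on the returned kernel are stand-ins, since _DEFAULT_HUB_ID_FA3 is defined elsewhere in kernels_utils.py:

from kernels import get_kernel

# Hypothetical stand-in for _DEFAULT_HUB_ID_FA3 (the real constant lives in kernels_utils.py).
FA3_HUB_ID = "kernels-community/vllm-flash-attn3"

# get_kernel fetches a pre-built kernel from the Hub and returns a module-like
# object whose attributes are the kernel's exported functions.
flash_attn_3_hub = get_kernel(FA3_HUB_ID)

# The wrapped custom op from attention_dispatch.py would then forward to something like
# (attribute name assumed, not verified):
# out, lse = flash_attn_3_hub.flash_attn_func(query, key, value)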
