
Commit 8cedc09

Commit message: up
Parent: 823ee2c

File tree: 1 file changed (+2, -6 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 2 additions & 6 deletions
@@ -1357,6 +1357,7 @@ def _flash_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_hub(
     query: torch.Tensor,
@@ -1368,9 +1369,6 @@ def _flash_attention_hub(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if _parallel_config:
-        raise NotImplementedError(f"{AttentionBackendName.FLASH_HUB.value} is not implemented for parallelism yet.")
-
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
     out = func(
@@ -1469,6 +1467,7 @@ def _flash_attention_3(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_3_hub(
     query: torch.Tensor,
@@ -1993,9 +1992,6 @@ def _sage_attention_hub(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if _parallel_config:
-        raise NotImplementedError(f"{AttentionBackendName.SAGE_HUB.value} is not implemented for parallelism yet.")
-
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
     if _parallel_config is None:
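
For context, the change moves the "not supported with context parallelism" error out of each hub backend's body and into the registration metadata via supports_context_parallel=False. The sketch below is a minimal, hypothetical illustration of how such a registry-level flag can gate dispatch centrally. Only _AttentionBackendRegistry.register(..., supports_context_parallel=...), AttentionBackendName.FLASH_HUB, and _parallel_config are taken from the diff above; the registry class, its dispatch helper, and the placeholder backend here are invented stand-ins, not the actual diffusers implementation.

# Minimal sketch, assuming a registry that stores a per-backend
# supports_context_parallel flag and checks it once at dispatch time.
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class _BackendEntry:
    fn: Callable
    constraints: List[Callable] = field(default_factory=list)
    supports_context_parallel: bool = True


class AttentionBackendRegistrySketch:
    """Illustrative stand-in for _AttentionBackendRegistry (hypothetical)."""

    _backends: Dict[str, _BackendEntry] = {}

    @classmethod
    def register(cls, name: str, constraints=None, supports_context_parallel: bool = True):
        # Store the flag alongside the backend instead of checking inside it.
        def decorator(fn):
            cls._backends[name] = _BackendEntry(
                fn=fn,
                constraints=constraints or [],
                supports_context_parallel=supports_context_parallel,
            )
            return fn

        return decorator

    @classmethod
    def dispatch(cls, name: str, *args, _parallel_config: Optional[object] = None, **kwargs):
        entry = cls._backends[name]
        # Centralized check: backends registered with supports_context_parallel=False
        # are rejected here, so their bodies no longer need an in-function guard.
        if _parallel_config is not None and not entry.supports_context_parallel:
            raise NotImplementedError(f"{name} is not implemented for parallelism yet.")
        return entry.fn(*args, _parallel_config=_parallel_config, **kwargs)


@AttentionBackendRegistrySketch.register("flash_hub", supports_context_parallel=False)
def flash_attention_hub_sketch(query, _parallel_config=None):
    return query  # placeholder body


print(AttentionBackendRegistrySketch.dispatch("flash_hub", "q"))  # works without a parallel config
# AttentionBackendRegistrySketch.dispatch("flash_hub", "q", _parallel_config=object())  # would raise

Centralizing the check this way keeps each backend's body free of parallelism boilerplate, and it lets the dispatcher (or callers) query context-parallel support from the registration metadata before invoking the kernel.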
