@@ -1357,6 +1357,7 @@ def _flash_attention(
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_hub(
     query: torch.Tensor,
@@ -1368,9 +1369,6 @@ def _flash_attention_hub(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if _parallel_config:
-        raise NotImplementedError(f"{AttentionBackendName.FLASH_HUB.value} is not implemented for parallelism yet.")
-
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
     out = func(
@@ -1469,6 +1467,7 @@ def _flash_attention_3(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_3_hub(
     query: torch.Tensor,
@@ -1993,9 +1992,6 @@ def _sage_attention_hub(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if _parallel_config:
-        raise NotImplementedError(f"{AttentionBackendName.SAGE_HUB.value} is not implemented for parallelism yet.")
-
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
     if _parallel_config is None:
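In effect, this change moves the "no context parallelism" guard out of the individual hub kernel wrappers and into the backend registration itself, via `supports_context_parallel=False`. Below is a minimal sketch of how such a registry-level check could work; the `_AttentionBackendRegistrySketch` class, its `dispatch` method, and the `_BackendEntry` dataclass are illustrative assumptions, not the actual diffusers implementation.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional


@dataclass
class _BackendEntry:
    # Hypothetical record of a registered backend and its capabilities.
    fn: Callable
    constraints: List[Callable] = field(default_factory=list)
    supports_context_parallel: bool = True


class _AttentionBackendRegistrySketch:
    _backends: Dict[str, _BackendEntry] = {}

    @classmethod
    def register(cls, name: str, constraints: Optional[List[Callable]] = None,
                 supports_context_parallel: bool = True):
        # Decorator that records the kernel wrapper together with its capabilities.
        def decorator(fn: Callable) -> Callable:
            cls._backends[name] = _BackendEntry(fn, constraints or [], supports_context_parallel)
            return fn
        return decorator

    @classmethod
    def dispatch(cls, name: str, *args, _parallel_config=None, **kwargs):
        entry = cls._backends[name]
        # Centralized guard: a backend registered with
        # supports_context_parallel=False is rejected here, so each wrapper
        # no longer needs its own NotImplementedError check.
        if _parallel_config is not None and not entry.supports_context_parallel:
            raise NotImplementedError(f"{name} is not implemented for context parallelism yet.")
        return entry.fn(*args, _parallel_config=_parallel_config, **kwargs)


# Usage sketch: registration declares the capability up front.
@_AttentionBackendRegistrySketch.register("flash_hub", supports_context_parallel=False)
def _flash_hub_example(query, _parallel_config=None):
    return query
```

Declaring the capability at registration time keeps the error message consistent across backends and lets the dispatch layer (or callers) check support before invoking a kernel, instead of duplicating the same guard in every wrapper.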