From 82d20e64a5cc79cd890c56dd2166dcaef29df193 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Dec 2025 14:39:07 +0530 Subject: [PATCH 1/3] up --- src/diffusers/models/attention_dispatch.py | 427 +++++++++++++++++++-- 1 file changed, 394 insertions(+), 33 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ffad94cc7f27..ffda25497643 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -256,6 +256,10 @@ class _HubKernelConfig: function_attr: str revision: Optional[str] = None kernel_fn: Optional[Callable] = None + wrapped_forward_attr: Optional[str] = None + wrapped_backward_attr: Optional[str] = None + wrapped_forward_fn: Optional[Callable] = None + wrapped_backward_fn: Optional[Callable] = None # Registry for hub-based attention kernels @@ -270,7 +274,11 @@ class _HubKernelConfig: # revision="fake-ops-return-probs", ), AttentionBackendName.FLASH_HUB: _HubKernelConfig( - repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None + repo_id="kernels-community/flash-attn2", + function_attr="flash_attn_func", + revision=None, + wrapped_forward_attr="_wrapped_flash_attn_forward", + wrapped_backward_attr="_wrapped_flash_attn_backward", ), AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None @@ -607,10 +615,15 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: kernel_module = get_kernel(config.repo_id, revision=config.revision) kernel_func = getattr(kernel_module, config.function_attr) - # Cache the downloaded kernel function in the config object config.kernel_fn = kernel_func + if config.wrapped_forward_attr is not None and config.wrapped_forward_attr is not None: + wrapped_forward_fn = getattr(kernel_module, config.wrapped_forward_attr) + wrapped_backward_fn = getattr(kernel_module, config.wrapped_backward_attr) + config.wrapped_forward_fn = wrapped_forward_fn + config.wrapped_backward_fn = wrapped_backward_fn + except Exception as e: logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") raise @@ -969,6 +982,246 @@ def _flash_attention_backward_op( return grad_query, grad_key, grad_value +def _maybe_format_lse_for_context_parallel( + lse: Optional[torch.Tensor], + *, + seq_len: int, + num_heads: int, +) -> Optional[torch.Tensor]: + if lse is None or lse.ndim != 3: + return lse + + if lse.shape[1] == num_heads and lse.shape[2] == seq_len: + lse = lse.permute(0, 2, 1) + + return lse.contiguous() + + +def _flash_attention_hub_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.") + + config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB] + wrapped_forward_fn = config.wrapped_forward_fn + wrapped_backward_fn = config.wrapped_backward_fn + if wrapped_forward_fn is None or wrapped_backward_fn is None: + raise RuntimeError( + "Flash 
attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` " + "for context parallel execution." + ) + + if scale is None: + scale = query.shape[-1] ** (-0.5) + + window_size = (-1, -1) + softcap = 0.0 + alibi_slopes = None + deterministic = False + grad_enabled = any(x.requires_grad for x in (query, key, value)) + + if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1): + dropout_p = dropout_p if dropout_p > 0 else 1e-30 + + with torch.set_grad_enabled(grad_enabled): + out, lse, S_dmask, rng_state = wrapped_forward_fn( + query, + key, + value, + dropout_p, + scale, + is_causal, + window_size[0], + window_size[1], + softcap, + alibi_slopes, + return_lse, + ) + lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2]) + + if _save_ctx: + ctx.save_for_backward(query, key, value, out, lse, rng_state) + ctx.dropout_p = dropout_p + ctx.scale = scale + ctx.is_causal = is_causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + + return (out, lse) if return_lse else out + + +def _flash_attention_hub_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB] + wrapped_backward_fn = config.wrapped_backward_fn + if wrapped_backward_fn is None: + raise RuntimeError( + "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution." + ) + + query, key, value, out, lse, rng_state = ctx.saved_tensors + grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) + + _ = wrapped_backward_fn( + grad_out, + query, + key, + value, + out, + lse, + grad_query, + grad_key, + grad_value, + ctx.dropout_p, + ctx.scale, + ctx.is_causal, + ctx.window_size[0], + ctx.window_size[1], + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state, + ) + + grad_query = grad_query[..., : grad_out.shape[-1]] + grad_key = grad_key[..., : grad_out.shape[-1]] + grad_value = grad_value[..., : grad_out.shape[-1]] + + return grad_query, grad_key, grad_value + + +def _flash_attention_3_hub_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, + *, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, +): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.") + if dropout_p != 0.0: + raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.") + + func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn + out = func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + 
deterministic=deterministic, + sm_margin=sm_margin, + return_attn_probs=return_lse, + ) + + lse = None + if return_lse: + out, lse = out + lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2]) + + if _save_ctx: + ctx.save_for_backward(query, key, value) + ctx.scale = scale + ctx.is_causal = is_causal + ctx._hub_kernel = func + + return (out, lse) if return_lse else out + + +def _flash_attention_3_hub_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + num_splits: int = 1, + pack_gqa: Optional[bool] = None, + deterministic: bool = False, + sm_margin: int = 0, +): + query, key, value = ctx.saved_tensors + kernel_fn = ctx._hub_kernel + with torch.enable_grad(): + query_r = query.detach().requires_grad_(True) + key_r = key.detach().requires_grad_(True) + value_r = value.detach().requires_grad_(True) + + out = kernel_fn( + q=query_r, + k=key_r, + v=value_r, + softmax_scale=ctx.scale, + causal=ctx.is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=num_splits, + pack_gqa=pack_gqa, + deterministic=deterministic, + sm_margin=sm_margin, + return_attn_probs=False, + ) + if isinstance(out, tuple): + out = out[0] + + grad_query, grad_key, grad_value = torch.autograd.grad( + out, + (query_r, key_r, value_r), + grad_out, + retain_graph=False, + allow_unused=False, + ) + + return grad_query, grad_key, grad_value + + def _sage_attention_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, @@ -1015,6 +1268,45 @@ def _sage_attention_backward_op( raise NotImplementedError("Backward pass is not implemented for Sage attention.") +def _sage_attention_hub_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + if attn_mask is not None: + raise ValueError("`attn_mask` is not yet supported for Sage attention.") + if dropout_p > 0.0: + raise ValueError("`dropout_p` is not yet supported for Sage attention.") + if enable_gqa: + raise ValueError("`enable_gqa` is not yet supported for Sage attention.") + + func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn + out = func( + q=query, + k=key, + v=value, + tensor_layout="NHD", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + lse = None + if return_lse: + out, lse, *_ = out + lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2]) + + return (out, lse) if return_lse else out + # ===== Context parallel ===== @@ -1372,7 +1664,7 @@ def _flash_attention( @_AttentionBackendRegistry.register( AttentionBackendName.FLASH_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=False, + supports_context_parallel=True, ) def _flash_attention_hub( query: torch.Tensor, @@ -1386,17 +1678,35 @@ def _flash_attention_hub( ) -> torch.Tensor: lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn - out = func( - q=query, - k=key, - v=value, - dropout_p=dropout_p, - softmax_scale=scale, - causal=is_causal, - return_attn_probs=return_lse, - ) 
- if return_lse: - out, lse, *_ = out + if _parallel_config is None: + out = func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + return_attn_probs=return_lse, + ) + if return_lse: + out, lse, *_ = out + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + dropout_p, + is_causal, + scale, + False, + return_lse, + forward_op=_flash_attention_hub_forward_op, + backward_op=_flash_attention_hub_backward_op, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse = out return (out, lse) if return_lse else out @@ -1539,7 +1849,7 @@ def _flash_attention_3( @_AttentionBackendRegistry.register( AttentionBackendName._FLASH_3_HUB, constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=False, + supports_context_parallel=True, ) def _flash_attention_3_hub( query: torch.Tensor, @@ -1553,31 +1863,65 @@ def _flash_attention_3_hub( return_attn_probs: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: - if _parallel_config: - raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.") - func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn - out = func( - q=query, - k=key, - v=value, - softmax_scale=scale, - causal=is_causal, - qv=None, - q_descale=None, - k_descale=None, - v_descale=None, + if _parallel_config is None: + out = func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + return_attn_probs=return_attn_probs, + ) + return (out[0], out[1]) if return_attn_probs else out + + forward_op = functools.partial( + _flash_attention_3_hub_forward_op, window_size=window_size, softcap=softcap, num_splits=1, pack_gqa=None, deterministic=deterministic, sm_margin=0, - return_attn_probs=return_attn_probs, ) - # When `return_attn_probs` is True, the above returns a tuple of - # actual outputs and lse. 
- return (out[0], out[1]) if return_attn_probs else out + backward_op = functools.partial( + _flash_attention_3_hub_backward_op, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = _templated_context_parallel_attention( + query, + key, + value, + None, + 0.0, + is_causal, + scale, + False, + return_attn_probs, + forward_op=forward_op, + backward_op=backward_op, + _parallel_config=_parallel_config, + ) + if return_attn_probs: + out, lse = out + return out, lse + + return out @_AttentionBackendRegistry.register( @@ -2107,7 +2451,7 @@ def _sage_attention( @_AttentionBackendRegistry.register( AttentionBackendName.SAGE_HUB, constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], - supports_context_parallel=False, + supports_context_parallel=True, ) def _sage_attention_hub( query: torch.Tensor, @@ -2132,6 +2476,23 @@ def _sage_attention_hub( ) if return_lse: out, lse, *_ = out + else: + out = _templated_context_parallel_attention( + query, + key, + value, + None, + 0.0, + is_causal, + scale, + False, + return_lse, + forward_op=_sage_attention_hub_forward_op, + backward_op=_sage_attention_backward_op, + _parallel_config=_parallel_config, + ) + if return_lse: + out, lse = out return (out, lse) if return_lse else out From 7a8f85b0473eb82db9114537c847a71fd8ab6f5e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Dec 2025 14:59:01 +0530 Subject: [PATCH 2/3] up --- src/diffusers/models/attention_dispatch.py | 35 +++++++++++++++------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ffda25497643..433815d7ed9b 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -277,8 +277,8 @@ class _HubKernelConfig: repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None, - wrapped_forward_attr="_wrapped_flash_attn_forward", - wrapped_backward_attr="_wrapped_flash_attn_backward", + wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward", + wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward", ), AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig( repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None @@ -602,27 +602,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): # ===== Helpers for downloading kernels ===== +def _resolve_kernel_attr(module, attr_path: str): + target = module + for attr in attr_path.split("."): + if not hasattr(target, attr): + raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.") + target = getattr(target, attr) + return target + + def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None: if backend not in _HUB_KERNELS_REGISTRY: return config = _HUB_KERNELS_REGISTRY[backend] - if config.kernel_fn is not None: + needs_kernel = config.kernel_fn is None + needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None + needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None + + if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward): return try: from kernels import get_kernel kernel_module = get_kernel(config.repo_id, revision=config.revision) - kernel_func = getattr(kernel_module, config.function_attr) - # Cache 
the downloaded kernel function in the config object - config.kernel_fn = kernel_func + if needs_kernel: + config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr) - if config.wrapped_forward_attr is not None and config.wrapped_forward_attr is not None: - wrapped_forward_fn = getattr(kernel_module, config.wrapped_forward_attr) - wrapped_backward_fn = getattr(kernel_module, config.wrapped_backward_attr) - config.wrapped_forward_fn = wrapped_forward_fn - config.wrapped_backward_fn = wrapped_backward_fn + if needs_wrapped_forward: + config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr) + + if needs_wrapped_backward: + config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr) except Exception as e: logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}") @@ -1307,6 +1319,7 @@ def _sage_attention_hub_forward_op( return (out, lse) if return_lse else out + # ===== Context parallel ===== From f732ff114403de85658d0f699d9d4b3bd0a32510 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Dec 2025 15:30:33 +0530 Subject: [PATCH 3/3] up --- src/diffusers/models/attention_dispatch.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 433815d7ed9b..96920c8631ec 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -994,21 +994,6 @@ def _flash_attention_backward_op( return grad_query, grad_key, grad_value -def _maybe_format_lse_for_context_parallel( - lse: Optional[torch.Tensor], - *, - seq_len: int, - num_heads: int, -) -> Optional[torch.Tensor]: - if lse is None or lse.ndim != 3: - return lse - - if lse.shape[1] == num_heads and lse.shape[2] == seq_len: - lse = lse.permute(0, 2, 1) - - return lse.contiguous() - - def _flash_attention_hub_forward_op( ctx: torch.autograd.function.FunctionCtx, query: torch.Tensor, @@ -1063,7 +1048,7 @@ def _flash_attention_hub_forward_op( alibi_slopes, return_lse, ) - lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2]) + lse = lse.permute(0, 2, 1).contiguous() if _save_ctx: ctx.save_for_backward(query, key, value, out, lse, rng_state) @@ -1173,7 +1158,7 @@ def _flash_attention_3_hub_forward_op( lse = None if return_lse: out, lse = out - lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2]) + lse = lse.permute(0, 2, 1).contiguous() if _save_ctx: ctx.save_for_backward(query, key, value) @@ -1315,7 +1300,7 @@ def _sage_attention_hub_forward_op( lse = None if return_lse: out, lse, *_ = out - lse = _maybe_format_lse_for_context_parallel(lse, seq_len=query.shape[1], num_heads=query.shape[2]) + lse = lse.permute(0, 2, 1).contiguous() return (out, lse) if return_lse else out
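Note on the new forward/backward ops introduced above: each hub backend (flash-attn2, flash-attn3, sage) gains context-parallel support by routing through _templated_context_parallel_attention with a per-backend forward op that returns (out, lse) — with lse permuted into (batch, seq_len, num_heads) layout — and a matching backward op; the flash-attn3 backward recomputes the forward on detached tensors under torch.enable_grad() and differentiates through it with torch.autograd.grad. The snippet below is a minimal, self-contained sketch of that contract, not the diffusers implementation: plain scaled_dot_product_attention stands in for the downloaded hub kernel, and toy_hub_forward_op, toy_hub_backward_op, and the SimpleNamespace ctx are illustrative stand-ins rather than names from this patch.

    from types import SimpleNamespace

    import torch
    import torch.nn.functional as F


    def toy_hub_forward_op(ctx, query, key, value, scale=None, is_causal=False, return_lse=False, _save_ctx=True):
        # query/key/value use the (batch, seq_len, num_heads, head_dim) layout of attention_dispatch;
        # SDPA expects (batch, num_heads, seq_len, head_dim), so permute in and out.
        q, k, v = (t.permute(0, 2, 1, 3) for t in (query, key, value))
        if scale is None:
            scale = q.shape[-1] ** -0.5

        out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=scale)
        out = out.permute(0, 2, 1, 3)

        lse = None
        if return_lse:
            # The real kernels return the log-sum-exp of the attention scores directly;
            # here it is recomputed from the (optionally causally masked) scores.
            scores = torch.matmul(q, k.transpose(-1, -2)) * scale  # (batch, heads, seq_q, seq_k)
            if is_causal:
                seq_q, seq_k = scores.shape[-2:]
                mask = torch.triu(torch.ones(seq_q, seq_k, device=scores.device), diagonal=1).bool()
                scores = scores.masked_fill(mask, float("-inf"))
            lse = torch.logsumexp(scores, dim=-1)  # (batch, heads, seq_q)
            # The ring-attention path consumes lse as (batch, seq_len, num_heads), hence the permute.
            lse = lse.permute(0, 2, 1).contiguous()

        if _save_ctx:
            ctx.saved = (query, key, value)
            ctx.scale, ctx.is_causal = scale, is_causal

        return (out, lse) if return_lse else out


    def toy_hub_backward_op(ctx, grad_out):
        # Backward-by-recompute, mirroring the flash-attn3 hub path above: rerun the forward on
        # detached tensors with grad enabled and let autograd produce dq/dk/dv.
        query, key, value = ctx.saved
        with torch.enable_grad():
            q = query.detach().requires_grad_(True)
            k = key.detach().requires_grad_(True)
            v = value.detach().requires_grad_(True)
            out = toy_hub_forward_op(SimpleNamespace(), q, k, v, scale=ctx.scale, is_causal=ctx.is_causal, _save_ctx=False)
        return torch.autograd.grad(out, (q, k, v), grad_out)


    if __name__ == "__main__":
        ctx = SimpleNamespace()
        q = torch.randn(2, 16, 4, 8)
        k = torch.randn(2, 16, 4, 8)
        v = torch.randn(2, 16, 4, 8)

        out, lse = toy_hub_forward_op(ctx, q, k, v, is_causal=True, return_lse=True)
        print(out.shape, lse.shape)  # torch.Size([2, 16, 4, 8]) torch.Size([2, 16, 4])

        dq, dk, dv = toy_hub_backward_op(ctx, torch.ones_like(out))
        print(dq.shape, dk.shape, dv.shape)

The dedicated wrapped backward used for flash-attn2 (config.wrapped_backward_fn) avoids this recompute by calling the kernel's own backward, which is why that backend additionally caches _wrapped_flash_attn_forward/_wrapped_flash_attn_backward in _HubKernelConfig.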
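For completeness, a hedged sketch of how a context-parallel hub backend would be exercised end to end. The entry points shown (ContextParallelConfig, enable_parallelism, set_attention_backend) and the backend string "flash_hub" are assumptions based on recent diffusers releases and on the enum names in this patch — they are not added by this diff, so check the current diffusers docs before relying on them.

    # Assumed public API; launched with `torchrun --nproc-per-node=2 run_cp.py`.
    import torch
    import torch.distributed as dist

    from diffusers import ContextParallelConfig, FluxPipeline  # assumed imports

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to(f"cuda:{rank}")

    # Pull the flash-attn2 kernel from the Hub and shard attention across ranks (ring attention).
    pipe.transformer.set_attention_backend("flash_hub")  # assumed backend string
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))  # assumed helper

    image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
    if rank == 0:
        image.save("output.png")

    dist.destroy_process_group()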