diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 02d17e8e9534..5390eccf784a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import inspect
 from typing import Callable, Dict, Optional, Tuple, Union
 
 import torch
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index c25e9997e3fb..1c6fc06fc5e8 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1244,31 +1244,21 @@ class FluxPosEmbed(nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
     def __init__(self, theta: int, axes_dim: List[int]):
         super().__init__()
+
+        from .transformers.transformer_flux import FluxPosEmbed as FluxPosEmbed_
+
+        deprecate(
+            "FluxPosEmbed",
+            "1.0.0",
+            "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please use `FluxPosEmbed` from `diffusers.models.transformers.transformer_flux` instead.",
+        )
+
         self.theta = theta
         self.axes_dim = axes_dim
+        self._rope = FluxPosEmbed_(theta, axes_dim)
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
+        return self._rope(ids)
 
 
 class TimestepEmbedding(nn.Module):
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 55ce0cf79fb9..fa5b09637835 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -599,6 +599,50 @@ def enable_group_offload(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+        backend = AttentionBackendName(backend)
+
+        for module in self.modules():
+            if not isinstance(module, AttentionModuleMixin):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention_processor import Attention, MochiAttention
+
+        attention_classes = (Attention, MochiAttention)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
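Reviewer note on the new `set_attention_backend` / `reset_attention_backend` methods above: a minimal usage sketch, not part of this PR, assuming flash-attn is installed and that `"flash"` is one of the `AttentionBackendName` values exposed by `attention_dispatch` (checkpoint and prompt are placeholders).

```python
# Minimal sketch, not part of this PR. Assumes flash-attn is installed and that
# "flash" is a valid AttentionBackendName value; otherwise a ValueError is raised.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# Route every module that opts into AttentionModuleMixin through the chosen backend.
pipe.transformer.set_attention_backend("flash")
image = pipe("a forest at dawn", num_inference_steps=28).images[0]

# Revert to the environment default / torch SDPA.
pipe.transformer.reset_attention_backend()
```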
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index d5bfacbb2c0b..2426fb3f0995 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -24,10 +24,15 @@
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, AttentionMixin, FeedForward
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    apply_rotary_emb,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -73,7 +78,6 @@ def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
         """Public method to get projections based on whether we're using fused mode or not."""
         if attn.is_fused and hasattr(attn, "to_qkv"):
             return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
-
         return self._get_projections(attn, hidden_states, encoder_hidden_states)
 
     def __call__(
@@ -117,17 +121,10 @@ def __call__(
             value = torch.cat([encoder_value, value], dim=2)
 
         if image_rotary_emb is not None:
-            from ..embeddings import apply_rotary_emb
-
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = dispatch_attention_fn(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-        )
+        hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -242,12 +239,10 @@ def __call__(
             value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
 
         if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -292,13 +287,76 @@ def __call__(
 
 
 @maybe_allow_in_graph
-class FluxAttention(Attention):
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = FluxAttnProcessor
     _available_processors = [
         FluxAttnProcessor,
         FluxIPAdapterAttnProcessor,
     ]
 
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        out_bias: bool = True,
+        eps: float = 1e-5,
+        out_dim: int = None,
+        context_pre_only: Optional[bool] = None,
+        pre_only: bool = False,
+        elementwise_affine: bool = True,
+        processor=None,
+    ):
+        super().__init__()
+        assert qk_norm == "rms_norm", "Flux uses RMSNorm"
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        self.added_proj_bias = added_proj_bias
+
+        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.pre_only:
+            self.to_out = torch.nn.ModuleList([])
+            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+
+        if added_kv_proj_dim is not None:
+            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
 
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
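With this refactor, `FluxAttention` is a plain `nn.Module` carrying `AttentionModuleMixin`, and `forward` simply defers to the processor. A small shape-check sketch follows; the dimensions are illustrative and do not come from a real Flux config.

```python
# Shape-check sketch only; dimensions are made up and not a real Flux configuration.
import torch
from diffusers.models.transformers.transformer_flux import FluxAttention

attn = FluxAttention(
    query_dim=128,
    heads=4,
    dim_head=32,
    qk_norm="rms_norm",     # the constructor asserts RMSNorm
    added_kv_proj_dim=128,  # enables the text-stream ("added") projections
)

img = torch.randn(1, 16, 128)  # image-stream tokens
txt = torch.randn(1, 8, 128)   # text-stream tokens

# FluxAttnProcessor returns both streams when encoder_hidden_states is passed.
img_out, txt_out = attn(img, encoder_hidden_states=txt)
print(img_out.shape, txt_out.shape)  # expected: (1, 16, 128) and (1, 8, 128)
```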
@@ -330,20 +388,19 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
-        residual = hidden_states
+        joint_attention_kwargs = joint_attention_kwargs or {}
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
         )
+        attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        proj_out = self.proj_out(attn_mlp_hidden_states)
+        hidden_states = hidden_states + gate.unsqueeze(1) * proj_out
 
-        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
-        gate = gate.unsqueeze(1)
-        hidden_states = gate * self.proj_out(hidden_states)
-        hidden_states = residual + hidden_states
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
 
@@ -386,12 +443,13 @@ def forward(
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        joint_attention_kwargs = joint_attention_kwargs or {}
 
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-        joint_attention_kwargs = joint_attention_kwargs or {}
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
@@ -410,7 +468,7 @@ def forward(
         hidden_states = hidden_states + attn_output
 
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
 
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -420,21 +478,54 @@ def forward(
             hidden_states = hidden_states + ip_attn_output
 
         # Process attention outputs for the `encoder_hidden_states`.
-
         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output
 
         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
         return encoder_hidden_states, hidden_states
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class FluxTransformer2DModel(
     ModelMixin,
     ConfigMixin,
@@ -537,10 +628,6 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-        # Using inherited methods from AttentionMixin
-
-        # Using inherited methods from AttentionMixin
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -634,11 +721,7 @@ def forward(
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )
 
             else:
@@ -665,12 +748,7 @@ def forward(
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    temb,
-                    image_rotary_emb,
-                )
+                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, image_rotary_emb)
 
             else:
                 hidden_states = block(
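Since `FluxPosEmbed` now lives next to the transformer, here is a short sketch of what its `forward` returns: per-axis 1D rotary cos/sin tables concatenated along the last dimension. The `theta=10000` and `axes_dim=[16, 56, 56]` values mirror the usual Flux configuration but are assumptions here, not read from a checkpoint.

```python
# Sketch of the relocated FluxPosEmbed; theta/axes_dim are assumed, not taken from a real config.
import torch
from diffusers.models.transformers.transformer_flux import FluxPosEmbed

rope = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

# One position id per axis: (num_tokens, 3) -> text index, latent row, latent column.
ids = torch.zeros(4096, 3)
ids[:, 1] = torch.arange(4096) // 64
ids[:, 2] = torch.arange(4096) % 64

freqs_cos, freqs_sin = rope(ids)
print(freqs_cos.shape, freqs_sin.shape)  # expected: torch.Size([4096, 128]) each (16 + 56 + 56 = 128)
```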
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 2df05cb8eb36..cadcedb98a14 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -67,6 +67,9 @@
     is_bitsandbytes_version,
     is_bs4_available,
     is_cosmos_guardrail_available,
+    is_flash_attn_3_available,
+    is_flash_attn_available,
+    is_flash_attn_version,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,
@@ -90,6 +93,8 @@
     is_peft_version,
     is_pytorch_retinaface_available,
     is_safetensors_available,
+    is_sageattention_available,
+    is_sageattention_version,
     is_scipy_available,
     is_sentencepiece_available,
     is_tensorboard_available,
@@ -108,6 +113,7 @@
     is_unidecode_available,
     is_wandb_available,
     is_xformers_available,
+    is_xformers_version,
     requires_backends,
 )
 from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 7c04287d33ed..f8f04cc03abd 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -41,6 +41,8 @@
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
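Both new constants are read from the environment when `diffusers` is imported, so they have to be set beforehand. A minimal sketch, with `"native"` used as an example backend name and the truthy strings following the existing `ENV_VARS_TRUE_VALUES` convention:

```python
# Sketch only: set the variables before importing diffusers, since constants.py reads them at import time.
import os

os.environ["DIFFUSERS_ATTN_BACKEND"] = "native"  # process-wide default backend name
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"        # opt into the dispatcher's extra input checks

from diffusers.utils import constants  # noqa: E402  # imported only after the environment is set

print(constants.DIFFUSERS_ATTN_BACKEND, constants.DIFFUSERS_ATTN_CHECKS)  # native True
```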
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f7244e97b878..4fe71801e8f9 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -219,6 +219,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
 
 
 def is_torch_available():
@@ -377,6 +380,18 @@ def is_hpu_available():
     return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
 
 
+def is_sageattention_available():
+    return _sageattention_available
+
+
+def is_flash_attn_available():
+    return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+    return _flash_attn_3_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -803,6 +818,51 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)
 
 
+def is_xformers_version(operation: str, version: str):
+    """
+    Compares the current xformers version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _xformers_available:
+        return False
+    return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+    """
+    Compares the current sageattention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _sageattention_available:
+        return False
+    return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+    """
+    Compares the current flash-attention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _flash_attn_available:
+        return False
+    return compare_versions(parse(_flash_attn_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
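The three new version helpers mirror the existing `is_*_version` utilities: they return `False` when the package is missing instead of raising. A sketch of how a caller might gate a backend choice on them; the `"2.6.3"` threshold and the backend strings are illustrative assumptions, not requirements stated in this PR.

```python
# Illustrative gating sketch; the version threshold and backend names are assumptions.
from diffusers.utils import (
    is_flash_attn_available,
    is_flash_attn_version,
    is_sageattention_available,
)


def pick_attention_backend() -> str:
    if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
        return "flash"
    if is_sageattention_available():
        return "sage"
    return "native"  # torch scaled dot product attention fallback


print(pick_attention_backend())
```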