diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index d23d05f7e38b..bd1c29009976 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -57,6 +57,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
@@ -68,6 +69,12 @@
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
+
 
 def save_model_card(
     repo_id: str,
@@ -189,6 +196,8 @@ def log_validation(
     del pipeline
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.npu.empty_cache()
 
     return images
 
@@ -1035,7 +1044,9 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            has_supported_fp16_accelerator = (
+                torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
+            )
             torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
@@ -1073,6 +1084,8 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif is_torch_npu_available():
+                torch_npu.npu.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1354,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_torch_npu_available():
+            torch_npu.npu.empty_cache()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1719,7 +1734,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    elif is_torch_npu_available():
+                        torch_npu.npu.empty_cache()
                     gc.collect()
 
     # Save the lora layers
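Note (not part of the diff): the training-script hunks above repeat the same device check everywhere the script frees accelerator memory. A minimal sketch of that pattern as a standalone helper follows; the helper name is illustrative and does not exist in the script.

import gc

import torch

from diffusers.utils.import_utils import is_torch_npu_available


def free_accelerator_memory():
    # Mirrors the branch added throughout train_dreambooth_flux.py: prefer the CUDA
    # cache when a GPU is present, otherwise fall back to the Ascend NPU cache.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_torch_npu_available():
        import torch_npu

        torch_npu.npu.empty_cache()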
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index db88ecbbb9d3..20c5cf3d925e 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1893,6 +1893,112 @@ def __call__(
         return hidden_states
 
 
+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
@@ -1987,6 +2093,117 @@ def __call__(
         return hidden_states
 
 
+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 5d39a1bb5391..f078cace0f3e 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -27,11 +27,13 @@
     Attention,
     AttentionProcessor,
     FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +66,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        processor = FluxAttnProcessor2_0()
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
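Note (not part of the diff): a minimal end-to-end sketch of how the new NPU attention path could be exercised. The checkpoint id, device string, and generation arguments below are assumptions; with torch_npu installed, FluxSingleTransformerBlock already picks FluxAttnProcessor2_0_NPU at construction time, so the explicit set_attn_processor call only makes that choice visible.

import torch

from diffusers import FluxPipeline
from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU
from diffusers.utils.import_utils import is_torch_npu_available

# Assumed checkpoint id; any checkpoint loadable by FluxPipeline should behave the same.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

if is_torch_npu_available():
    pipe.to("npu")
    # Optional: route every attention layer of the transformer through the NPU processor.
    # In bf16/fp16 this dispatches to torch_npu.npu_fusion_attention; otherwise it falls
    # back to F.scaled_dot_product_attention, as in the diff above.
    pipe.transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())

image = pipe("a photo of sks dog", num_inference_steps=28).images[0]
image.save("flux_npu.png")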