diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py
index 8894b1447..003bf22a4 100644
--- a/ppdiffusers/ppdiffusers/models/attention_processor.py
+++ b/ppdiffusers/ppdiffusers/models/attention_processor.py
@@ -20,7 +20,7 @@
 from paddle import einsum, nn

 from ..utils import USE_PEFT_BACKEND, deprecate, logging
-from ..utils.import_utils import is_ppxformers_available
+from ..utils.import_utils import is_npu_available, is_ppxformers_available
 from ..utils.paddle_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer

@@ -111,7 +111,7 @@ def __init__(
         super().__init__()

         # To prevent circular import.
-        from .normalization import RMSNorm, FP32LayerNorm, LpNorm
+        from .normalization import FP32LayerNorm, LpNorm, RMSNorm

         self.inner_dim = dim_head * heads
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@@ -251,7 +251,7 @@ def __init__(
         # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
         # paddle.nn.functional.scaled_dot_product_attention_ for native Flash/memory_efficient_attention
         if processor is None:
-            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
+            processor = AttnProcessor2_5() if is_ppxformers_available() or is_npu_available() else AttnProcessor()
         self.set_processor(processor)

     @property
@@ -997,12 +997,20 @@ def __call__(
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
         encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.reshape([batch_size, -1, attn.heads, head_dim])
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.reshape([batch_size, -1, attn.heads, head_dim])
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.reshape([batch_size, -1, attn.heads, head_dim])
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.reshape(
+            [batch_size, -1, attn.heads, head_dim]
+        )
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.reshape(
+            [batch_size, -1, attn.heads, head_dim]
+        )
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.reshape(
+            [batch_size, -1, attn.heads, head_dim]
+        )

         if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj, begin_norm_axis=3)
+            encoder_hidden_states_query_proj = attn.norm_added_q(
+                encoder_hidden_states_query_proj, begin_norm_axis=3
+            )
         if attn.norm_added_k is not None:
             encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj, begin_norm_axis=3)

diff --git a/ppdiffusers/ppdiffusers/patches/paddle_patch.py b/ppdiffusers/ppdiffusers/patches/paddle_patch.py
index b5a507f93..83634ad30 100644
--- a/ppdiffusers/ppdiffusers/patches/paddle_patch.py
+++ b/ppdiffusers/ppdiffusers/patches/paddle_patch.py
@@ -351,27 +351,30 @@ def to(self=None, device=None, dtype=None, blocking=None):

 nn.Layer.to = to

-from ..utils.import_utils import is_ppxformers_available, is_npu_available
+from ..utils.import_utils import is_npu_available, is_ppxformers_available

 if is_npu_available():
     for lib in os.listdir(os.getenv("CUSTOM_DEVICE_ROOT")):
         if lib.endswith(".so"):
-            paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(
-                lib
-            )
+            paddle.utils.cpp_extension.extension_utils.load_op_meta_info_and_register_op(lib)
     from paddle.base import core
-    def scaled_dot_product_attention_npu(query,
-                                         key,
-                                         value,
-                                         attn_mask=None,
-                                         dropout_p=0.0,
-                                         is_causal=False,
-                                         training=True,
-                                         name=None,
-                                         fixed_seed_offset=None,
-                                         return_softmax=False,
-                                         is_triangle_upper_mask=True,
-                                         ):
+
+    def scaled_dot_product_attention_npu(
+        query,
+        key,
+        value,
+        attn_mask=None,
+        actual_seq_q_len=None,
+        actual_seq_kv_len=None,
+        dropout_p=0.0,
+        is_causal=False,
+        training=True,
+        name=None,
+        fixed_seed_offset=None,
+        return_softmax=False,
+        is_triangle_upper_mask=True,
+        is_varlen=False,
+    ):
         out = core.eager._run_custom_op(
             "flash_attention_npu",
             query,
@@ -379,13 +382,17 @@ def scaled_dot_product_attention_npu(query,
             value,
             fixed_seed_offset,
             attn_mask,
+            actual_seq_q_len,
+            actual_seq_kv_len,
             dropout_p,
             is_causal,
             return_softmax,
             not training,
             is_triangle_upper_mask,
+            is_varlen,
         )[0]
         return out
+
     paddle.nn.functional.scaled_dot_product_attention_npu = scaled_dot_product_attention_npu

 if is_ppxformers_available() or is_npu_available():
@@ -407,8 +414,9 @@ def scaled_dot_product_attention_npu(query,
             paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
             attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
         )
-
+
         from paddle.nn.functional.flash_attention import flash_attention
+
         _ = flash_attention(
             paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
             paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
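
For reference, a minimal usage sketch of the NPU attention entry point this patch registers (not part of the diff). It assumes that importing ppdiffusers runs paddle_patch.py, which loads the custom-device .so files, registers the "flash_attention_npu" custom op, and attaches scaled_dot_product_attention_npu to paddle.nn.functional; the standalone-script form and the guard via is_npu_available() are assumptions for illustration, and the toy float16 tensors are copied from the patch's own self-test.

# Hedged sketch: call the patched NPU scaled-dot-product attention wrapper.
import paddle
import ppdiffusers  # noqa: F401  -- assumption: importing the package applies paddle_patch.py

from ppdiffusers.utils.import_utils import is_npu_available

if is_npu_available():
    # Same toy tensors as the patch's self-test, in paddle's
    # [batch, seq_len, num_heads, head_dim] layout.
    q = paddle.ones((1, 1, 2, 40), dtype=paddle.float16)
    k = paddle.ones((1, 1, 2, 40), dtype=paddle.float16)
    v = paddle.ones((1, 1, 2, 40), dtype=paddle.float16)
    # The wrapper forwards to core.eager._run_custom_op("flash_attention_npu", ...)
    # and returns only the attention output (element [0] of the op's results).
    out = paddle.nn.functional.scaled_dot_product_attention_npu(q, k, v, dropout_p=0.0, is_causal=False)
    print(out.shape)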