From c055439bb779a8efb2cfb9c5061fafc391050ef1 Mon Sep 17 00:00:00 2001
From: warrentdrew
Date: Fri, 11 Apr 2025 16:39:03 +0800
Subject: [PATCH 1/2] fix flash attn

---
 .../ppdiffusers/models/attention_processor.py | 19 ++++++---
 ppdiffusers/ppdiffusers/utils/import_utils.py | 39 ++++++++++++++-----
 2 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py
index 86192f24d..5841f86ec 100644
--- a/ppdiffusers/ppdiffusers/models/attention_processor.py
+++ b/ppdiffusers/ppdiffusers/models/attention_processor.py
@@ -21,7 +21,11 @@
 from paddle import einsum, nn
 
 from ..utils import USE_PEFT_BACKEND, deprecate, logging
-from ..utils.import_utils import is_ppxformers_available
+from ..utils.import_utils import (
+    is_flash_attention_available,
+    is_npu_available,
+    is_ppxformers_available,
+)
 from ..utils.paddle_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer
 
@@ -256,7 +260,8 @@ def __init__(
         # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
         # paddle.nn.functional.scaled_dot_product_attention_ for native Flash/memory_efficient_attention
         if processor is None:
-            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
+            processor = AttnProcessor2_5() if is_flash_attention_available() else AttnProcessor()
+            processor = AttnProcessor() if is_npu_available() else processor
         self.set_processor(processor)
 
     @property
@@ -373,7 +378,8 @@ def set_use_memory_efficient_attention_xformers(
             # set attention processor
             # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
             # paddle.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
+            processor = AttnProcessor2_5() if is_flash_attention_available() else AttnProcessor()
+            processor = AttnProcessor() if is_npu_available() else processor
 
         self.set_processor(processor)
 
@@ -398,7 +404,8 @@ def set_attention_slice(self, slice_size: int) -> None:
             # set attention processor
             # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
             # paddle.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
+            processor = AttnProcessor2_5() if is_flash_attention_available() else AttnProcessor()
+            processor = AttnProcessor() if is_npu_available() else processor
 
         self.set_processor(processor)
 
@@ -665,7 +672,7 @@ def prepare_attention_mask(
             num_heads = self.heads
         if attention_mask is None:
             return attention_mask
-
+
         ori_type = attention_mask.dtype
         attention_mask = attention_mask.to(paddle.float32)
 
@@ -1296,7 +1303,7 @@ def __call__(
         # adapt the scaled_dot_product_attention_ when attention_mask is a bool tensor
         if attention_mask is not None and attention_mask.dtype == paddle.bool:
             L, S = query.shape[1], key.shape[1]
-            attention_mask_tmp = paddle.zeros([1,1, L, S], dtype=query.dtype)
+            attention_mask_tmp = paddle.zeros([1, 1, L, S], dtype=query.dtype)
             attention_mask_tmp = attention_mask_tmp.masked_fill(attention_mask.logical_not(), float("-inf"))
             attention_mask = attention_mask_tmp
 
diff --git a/ppdiffusers/ppdiffusers/utils/import_utils.py b/ppdiffusers/ppdiffusers/utils/import_utils.py
index d701ad6bd..93b55c9bf 100644
--- a/ppdiffusers/ppdiffusers/utils/import_utils.py
+++ b/ppdiffusers/ppdiffusers/utils/import_utils.py
@@ -65,6 +65,7 @@ def str2bool(v):
 if USE_PADDLE in ENV_VARS_TRUE_AND_AUTO_VALUES:
     _paddle_available = importlib.util.find_spec("paddle") is not None
     _ppxformers_available = False
+    _flash_attention_available = False
     if _paddle_available:
         try:
             import paddle
@@ -75,24 +76,36 @@ def str2bool(v):
             _paddle_available = False
 
     if _paddle_available:
+        # try:
+        #     from paddle.incubate.nn.memory_efficient_attention import (  # noqa
+        #         memory_efficient_attention,
+        #     )
+
+        #     # _ = memory_efficient_attention(
+        #     #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+        #     #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+        #     #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+        #     # )
+        #     _ppxformers_available = True
+        # except Exception:
+        #     _ppxformers_available = False
+
         try:
-            from paddle.incubate.nn.memory_efficient_attention import (  # noqa
-                memory_efficient_attention,
+            _ = paddle.nn.functional.scaled_dot_product_attention(
+                paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+                paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+                paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+                attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
             )
-
-            # _ = memory_efficient_attention(
-            #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-            #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-            #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-            # )
-            _ppxformers_available = True
+            _flash_attention_available = True
         except Exception:
-            _ppxformers_available = False
+            _flash_attention_available = False
 else:
     logger.info("Disabling Paddle because USE_PADDLE is set")
     _paddle_available = False
     _ppxformers_available = False
+    _flash_attention_available = False
 
 
 _torch_version = "N/A"
 _torch_available = importlib.util.find_spec("torch") is not None
@@ -375,9 +388,11 @@ def is_scipy_available():
 def is_librosa_available():
     return _librosa_available
 
+
 def is_npu_available():
     return paddle.device.get_device().startswith("npu")
 
+
 def is_ppxformers_available():
     USE_PPXFORMERS = str2bool(os.getenv("USE_PPXFORMERS", True))
     if USE_PPXFORMERS:
@@ -386,6 +401,10 @@ def is_ppxformers_available():
     return False
 
 
+def is_flash_attention_available():
+    return _flash_attention_available
+
+
 # NOTE this is paddle accelerate
 def is_accelerate_available():
     return _accelerate_available

From 51f2240d3a51f9f6beec31dbb71fb20f00da92f8 Mon Sep 17 00:00:00 2001
From: warrentdrew
Date: Fri, 11 Apr 2025 17:17:53 +0800
Subject: [PATCH 2/2] fix flash attn

---
 .../ppdiffusers/models/attention_processor.py | 15 +++-----
 ppdiffusers/ppdiffusers/utils/import_utils.py | 35 +++++--------------
 2 files changed, 13 insertions(+), 37 deletions(-)

diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py
index 5841f86ec..ac3218f33 100644
--- a/ppdiffusers/ppdiffusers/models/attention_processor.py
+++ b/ppdiffusers/ppdiffusers/models/attention_processor.py
@@ -21,11 +21,7 @@
 from paddle import einsum, nn
 
 from ..utils import USE_PEFT_BACKEND, deprecate, logging
-from ..utils.import_utils import (
-    is_flash_attention_available,
-    is_npu_available,
-    is_ppxformers_available,
-)
+from ..utils.import_utils import is_ppxformers_available
 from ..utils.paddle_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer
 
@@ -260,8 +256,7 @@ def __init__(
         # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
         # paddle.nn.functional.scaled_dot_product_attention_ for native Flash/memory_efficient_attention
         if processor is None:
-            processor = AttnProcessor2_5() if is_flash_attention_available() else AttnProcessor()
-            processor = AttnProcessor() if is_npu_available() else processor
+            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
         self.set_processor(processor)
 
     @property
@@ -378,8 +373,7 @@ def set_use_memory_efficient_attention_xformers(
             # set attention processor
             # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
             # paddle.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-            processor = AttnProcessor2_5() if is_flash_attention_available() else AttnProcessor()
-            processor = AttnProcessor() if is_npu_available() else processor
+            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
 
         self.set_processor(processor)
 
@@ -404,8 +398,7 @@ def set_attention_slice(self, slice_size: int) -> None:
             # set attention processor
             # We use the AttnProcessor2_5 by default when paddle 2.5 is used which uses
             # paddle.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
-            processor = AttnProcessor2_5() if is_flash_attention_available() else AttnProcessor()
-            processor = AttnProcessor() if is_npu_available() else processor
+            processor = AttnProcessor2_5() if is_ppxformers_available() else AttnProcessor()
 
         self.set_processor(processor)
 
diff --git a/ppdiffusers/ppdiffusers/utils/import_utils.py b/ppdiffusers/ppdiffusers/utils/import_utils.py
index 93b55c9bf..410956aa2 100644
--- a/ppdiffusers/ppdiffusers/utils/import_utils.py
+++ b/ppdiffusers/ppdiffusers/utils/import_utils.py
@@ -42,6 +42,10 @@ def str2bool(v):
         raise ValueError("Not supported value: {}".format(v))
 
 
+def is_npu_available():
+    return paddle.device.get_device().startswith("npu")
+
+
 # The package importlib_metadata is in a different place, depending on the python version.
 if sys.version_info < (3, 8):
     import importlib_metadata
@@ -65,7 +69,6 @@ def str2bool(v):
 if USE_PADDLE in ENV_VARS_TRUE_AND_AUTO_VALUES:
     _paddle_available = importlib.util.find_spec("paddle") is not None
     _ppxformers_available = False
-    _flash_attention_available = False
     if _paddle_available:
         try:
             import paddle
@@ -76,20 +79,6 @@ def str2bool(v):
             _paddle_available = False
 
     if _paddle_available:
-        # try:
-        #     from paddle.incubate.nn.memory_efficient_attention import (  # noqa
-        #         memory_efficient_attention,
-        #     )
-
-        #     # _ = memory_efficient_attention(
-        #     #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     # )
-        #     _ppxformers_available = True
-        # except Exception:
-        #     _ppxformers_available = False
-
         try:
             _ = paddle.nn.functional.scaled_dot_product_attention(
                 paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
@@ -97,15 +86,17 @@ def str2bool(v):
                 paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
                 attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
             )
-            _flash_attention_available = True
+            _ppxformers_available = True
         except Exception:
-            _flash_attention_available = False
+            _ppxformers_available = False
+
+        if is_npu_available():
+            _ppxformers_available = False
 else:
     logger.info("Disabling Paddle because USE_PADDLE is set")
     _paddle_available = False
     _ppxformers_available = False
-    _flash_attention_available = False
 
 
 _torch_version = "N/A"
 _torch_available = importlib.util.find_spec("torch") is not None
@@ -389,10 +380,6 @@ def is_librosa_available():
     return _librosa_available
 
 
-def is_npu_available():
-    return paddle.device.get_device().startswith("npu")
-
-
 def is_ppxformers_available():
     USE_PPXFORMERS = str2bool(os.getenv("USE_PPXFORMERS", True))
     if USE_PPXFORMERS:
@@ -401,10 +388,6 @@ def is_ppxformers_available():
     return False
 
 
-def is_flash_attention_available():
-    return _flash_attention_available
-
-
 # NOTE this is paddle accelerate
 def is_accelerate_available():
     return _accelerate_available
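
The standalone sketch below mirrors the behavior the series converges on in PATCH 2/2: probe paddle.nn.functional.scaled_dot_product_attention once at import time, treat NPU devices as unsupported for the fused kernel, and only then default to AttnProcessor2_5. It is a minimal illustration, assuming paddle and ppdiffusers are importable; the names _probe_scaled_dot_product_attention, _sdpa_usable, and default_attn_processor are illustrative only and do not exist in the codebase.

# Illustrative sketch of the probe-and-select logic from PATCH 2/2 (names are hypothetical).
import paddle

from ppdiffusers.models.attention_processor import AttnProcessor, AttnProcessor2_5


def _probe_scaled_dot_product_attention() -> bool:
    """Return True only if the fused SDPA kernel actually runs on this device."""
    if paddle.device.get_device().startswith("npu"):
        # Same rule as is_npu_available() in the patch: keep the plain processor on NPU.
        return False
    try:
        # The same tiny float16 call the patch uses as a capability probe.
        paddle.nn.functional.scaled_dot_product_attention(
            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
            attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
        )
        return True
    except Exception:
        # On CPU or older Paddle builds the call raises; fall back to AttnProcessor.
        return False


_sdpa_usable = _probe_scaled_dot_product_attention()  # evaluated once, like _ppxformers_available


def default_attn_processor():
    # Mirrors Attention.__init__ after PATCH 2/2: fused path when available, plain processor otherwise.
    return AttnProcessor2_5() if _sdpa_usable else AttnProcessor()


print("fused SDPA usable:", _sdpa_usable, "->", type(default_attn_processor()).__name__)

Running the probe once and caching the result keeps the per-layer processor choice cheap; the exception handler, not a version check, is what decides whether the fused kernel is trusted on the current device.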