diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py
index 86192f24d..ac3218f33 100644
--- a/ppdiffusers/ppdiffusers/models/attention_processor.py
+++ b/ppdiffusers/ppdiffusers/models/attention_processor.py
@@ -665,7 +665,7 @@ def prepare_attention_mask(
         num_heads = self.heads
         if attention_mask is None:
             return attention_mask
-
+
         ori_type = attention_mask.dtype
         attention_mask = attention_mask.to(paddle.float32)
@@ -1296,7 +1296,7 @@ def __call__(
         # adapt the scaled_dot_product_attention_ when attention_mask is a bool tensor
         if attention_mask is not None and attention_mask.dtype == paddle.bool:
             L, S = query.shape[1], key.shape[1]
-            attention_mask_tmp = paddle.zeros([1,1, L, S], dtype=query.dtype)
+            attention_mask_tmp = paddle.zeros([1, 1, L, S], dtype=query.dtype)
             attention_mask_tmp = attention_mask_tmp.masked_fill(attention_mask.logical_not(), float("-inf"))
             attention_mask = attention_mask_tmp
diff --git a/ppdiffusers/ppdiffusers/utils/import_utils.py b/ppdiffusers/ppdiffusers/utils/import_utils.py
index d701ad6bd..410956aa2 100644
--- a/ppdiffusers/ppdiffusers/utils/import_utils.py
+++ b/ppdiffusers/ppdiffusers/utils/import_utils.py
@@ -42,6 +42,10 @@ def str2bool(v):
         raise ValueError("Not supported value: {}".format(v))
 
 
+def is_npu_available():
+    return paddle.device.get_device().startswith("npu")
+
+
 # The package importlib_metadata is in a different place, depending on the python version.
 if sys.version_info < (3, 8):
     import importlib_metadata
@@ -76,19 +80,19 @@ def str2bool(v):
 if _paddle_available:
     try:
-        from paddle.incubate.nn.memory_efficient_attention import (  # noqa
-            memory_efficient_attention,
+        _ = paddle.nn.functional.scaled_dot_product_attention(
+            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
+            attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
         )
-
-        # _ = memory_efficient_attention(
-        #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        #     paddle.ones((1, 1, 2, 40), dtype=paddle.float16),
-        # )
         _ppxformers_available = True
     except Exception:
         _ppxformers_available = False
+    if is_npu_available():
+        _ppxformers_available = False
+
 else:
     logger.info("Disabling Paddle because USE_PADDLE is set")
     _paddle_available = False
@@ -375,8 +379,6 @@ def is_scipy_available():
 def is_librosa_available():
     return _librosa_available
 
-def is_npu_available():
-    return paddle.device.get_device().startswith("npu")
 
 def is_ppxformers_available():
     USE_PPXFORMERS = str2bool(os.getenv("USE_PPXFORMERS", True))
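
Note: the patch replaces the old import-based check (which only verified that the legacy memory_efficient_attention symbol could be imported) with an actual scaled_dot_product_attention call, so _ppxformers_available now reflects whether the fused kernel really executes on the current device, and it is forced off on NPU. Below is a minimal standalone sketch of the same probe logic; the helper name probe_sdp_attention is hypothetical and not part of the patch.

import paddle

def is_npu_available():
    # Mirrors the helper added to import_utils.py: inspects the active device string.
    return paddle.device.get_device().startswith("npu")

def probe_sdp_attention():
    # Hypothetical standalone version of the patched availability check:
    # actually run scaled_dot_product_attention on tiny fp16 tensors
    # (shape [batch, seq_len, num_heads, head_dim], as in the patch)
    # instead of merely importing a symbol.
    try:
        paddle.nn.functional.scaled_dot_product_attention(
            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),  # query
            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),  # key
            paddle.ones((1, 1, 2, 40), dtype=paddle.float16),  # value
            attn_mask=paddle.ones((1, 2, 1, 1), dtype=paddle.float16),
        )
        available = True
    except Exception:
        # Any failure (missing kernel, unsupported dtype/device) means unavailable.
        available = False
    # Per the patch, the fused kernel is not used on NPU even if the call succeeds.
    if is_npu_available():
        available = False
    return available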