@@ -29,6 +29,7 @@
     FluxAttnProcessor2_0,
     FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0_NPU,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -140,7 +141,10 @@ def __init__(
         self.norm1_context = AdaLayerNormZero(dim)
 
         if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
+            if is_torch_npu_available():
+                processor = FluxAttnProcessor2_0_NPU()
+            else:
+                processor = FluxAttnProcessor2_0()
         else:
             raise ValueError(
                 "The current PyTorch version does not support the `scaled_dot_product_attention` function."
@@ -405,7 +409,10 @@ def fuse_qkv_projections(self):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+        if is_torch_npu_available():
+            self.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
+        else:
+            self.set_attn_processor(FusedFluxAttnProcessor2_0())
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
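
For context, a minimal sketch of how this change surfaces to callers: the NPU processors are selected automatically, so no caller-side API changes are needed. The checkpoint id and `subfolder` below are illustrative assumptions, not part of this diff.

```python
import torch
from diffusers import FluxTransformer2DModel

# __init__ now installs FluxAttnProcessor2_0_NPU instead of FluxAttnProcessor2_0
# when torch_npu is available; no change is needed in user code.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint; any Flux checkpoint works
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# fuse_qkv_projections() now sets FusedFluxAttnProcessor2_0_NPU on NPU,
# FusedFluxAttnProcessor2_0 everywhere else.
transformer.fuse_qkv_projections()
print({type(p).__name__ for p in transformer.attn_processors.values()})
```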