12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import inspect |
| 15 | +import math |
15 | 16 | from importlib import import_module |
16 | 17 | from typing import Callable, List, Optional, Union |
17 | 18 |
21 | 22 |
22 | 23 | from ..image_processor import IPAdapterMaskProcessor |
23 | 24 | from ..utils import deprecate, logging |
24 | | -from ..utils.import_utils import is_xformers_available |
| 25 | +from ..utils.import_utils import is_torch_npu_available, is_xformers_available |
25 | 26 | from ..utils.torch_utils import maybe_allow_in_graph |
26 | 27 | from .lora import LoRALinearLayer |
27 | 28 |
28 | 29 |
29 | 30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
30 | 31 |
| 32 | +if is_torch_npu_available(): |
| 33 | + import torch_npu |
31 | 34 |
32 | 35 | if is_xformers_available(): |
33 | 36 | import xformers |
@@ -209,6 +212,23 @@ def __init__( |
209 | 212 | ) |
210 | 213 | self.set_processor(processor) |
211 | 214 |
| 215 | + def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None: |
| 216 | + r""" |
| 217 | + Set whether to use npu flash attention from `torch_npu` or not. |
| 218 | +
| 219 | + """ |
| 220 | + if use_npu_flash_attention: |
| 221 | + processor = AttnProcessorNPU() |
| 222 | + else: |
| 223 | + # set attention processor |
| 224 | + # We default to AttnProcessor2_0 when torch 2.x is used, since it relies on |
| 225 | + # torch.nn.functional.scaled_dot_product_attention for native flash/memory-efficient attention, |
| 226 | + # but only when the default `scale` argument is used. TODO: remove the scale_qk check once we move to torch 2.1 |
| 227 | + processor = ( |
| 228 | + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() |
| 229 | + ) |
| 230 | + self.set_processor(processor) |
| 231 | + |
212 | 232 | def set_use_memory_efficient_attention_xformers( |
213 | 233 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None |
214 | 234 | ) -> None: |
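
As a usage sketch (not part of this diff): the new toggle can be flipped on every `Attention` module of a loaded model. The checkpoint name below is illustrative, and the sketch assumes a working `torch_npu` install with an NPU device available.

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention

# Illustrative checkpoint; any model built from Attention modules works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

for module in unet.modules():
    if isinstance(module, Attention):
        # True installs AttnProcessorNPU; False restores AttnProcessor2_0 (or AttnProcessor).
        module.set_use_npu_flash_attention(True)
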
@@ -1207,6 +1227,116 @@ def __call__( |
1207 | 1227 | return hidden_states |
1208 | 1228 |
1209 | 1229 |
| 1230 | +class AttnProcessorNPU: |
| 1231 | + |
| 1232 | + r""" |
| 1233 | + Processor for implementing flash attention using torch_npu. The fused attention kernel supports only the fp16 and |
| 1234 | + bf16 data types. If fp32 is used, F.scaled_dot_product_attention is used for the computation instead, but the |
| 1235 | + acceleration effect on NPU is not significant. |
| 1236 | +
| 1237 | + """ |
| 1238 | + |
| 1239 | + def __init__(self): |
| 1240 | + if not is_torch_npu_available(): |
| 1241 | + raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") |
| 1242 | + |
| 1243 | + def __call__( |
| 1244 | + self, |
| 1245 | + attn: Attention, |
| 1246 | + hidden_states: torch.FloatTensor, |
| 1247 | + encoder_hidden_states: Optional[torch.FloatTensor] = None, |
| 1248 | + attention_mask: Optional[torch.FloatTensor] = None, |
| 1249 | + temb: Optional[torch.FloatTensor] = None, |
| 1250 | + *args, |
| 1251 | + **kwargs, |
| 1252 | + ) -> torch.FloatTensor: |
| 1253 | + if len(args) > 0 or kwargs.get("scale", None) is not None: |
| 1254 | + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." |
| 1255 | + deprecate("scale", "1.0.0", deprecation_message) |
| 1256 | + |
| 1257 | + residual = hidden_states |
| 1258 | + if attn.spatial_norm is not None: |
| 1259 | + hidden_states = attn.spatial_norm(hidden_states, temb) |
| 1260 | + |
| 1261 | + input_ndim = hidden_states.ndim |
| 1262 | + |
| 1263 | + if input_ndim == 4: |
| 1264 | + batch_size, channel, height, width = hidden_states.shape |
| 1265 | + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) |
| 1266 | + |
| 1267 | + batch_size, sequence_length, _ = ( |
| 1268 | + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
| 1269 | + ) |
| 1270 | + |
| 1271 | + if attention_mask is not None: |
| 1272 | + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) |
| 1273 | + # scaled_dot_product_attention expects attention_mask shape to be |
| 1274 | + # (batch, heads, source_length, target_length) |
| 1275 | + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) |
| 1276 | + |
| 1277 | + if attn.group_norm is not None: |
| 1278 | + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) |
| 1279 | + |
| 1280 | + query = attn.to_q(hidden_states) |
| 1281 | + |
| 1282 | + if encoder_hidden_states is None: |
| 1283 | + encoder_hidden_states = hidden_states |
| 1284 | + elif attn.norm_cross: |
| 1285 | + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) |
| 1286 | + |
| 1287 | + key = attn.to_k(encoder_hidden_states) |
| 1288 | + value = attn.to_v(encoder_hidden_states) |
| 1289 | + |
| 1290 | + inner_dim = key.shape[-1] |
| 1291 | + head_dim = inner_dim // attn.heads |
| 1292 | + |
| 1293 | + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 1294 | + |
| 1295 | + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 1296 | + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
| 1297 | + |
| 1298 | + # the output of sdp = (batch, num_heads, seq_len, head_dim) |
| 1299 | + if query.dtype in (torch.float16, torch.bfloat16): |
| 1300 | + hidden_states = torch_npu.npu_fusion_attention( |
| 1301 | + query, |
| 1302 | + key, |
| 1303 | + value, |
| 1304 | + attn.heads, |
| 1305 | + input_layout="BNSD", |
| 1306 | + pse=None, |
| 1307 | + atten_mask=attention_mask, |
| 1308 | + scale=1.0 / math.sqrt(query.shape[-1]), |
| 1309 | + pre_tockens=65536, |
| 1310 | + next_tockens=65536, |
| 1311 | + keep_prob=1.0, |
| 1312 | + sync=False, |
| 1313 | + inner_precise=0, |
| 1314 | + )[0] |
| 1315 | + else: |
| 1316 | + # TODO: add support for attn.scale when we move to Torch 2.1 |
| 1317 | + hidden_states = F.scaled_dot_product_attention( |
| 1318 | + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False |
| 1319 | + ) |
| 1320 | + |
| 1321 | + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
| 1322 | + hidden_states = hidden_states.to(query.dtype) |
| 1323 | + |
| 1324 | + # linear proj |
| 1325 | + hidden_states = attn.to_out[0](hidden_states) |
| 1326 | + # dropout |
| 1327 | + hidden_states = attn.to_out[1](hidden_states) |
| 1328 | + |
| 1329 | + if input_ndim == 4: |
| 1330 | + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) |
| 1331 | + |
| 1332 | + if attn.residual_connection: |
| 1333 | + hidden_states = hidden_states + residual |
| 1334 | + |
| 1335 | + hidden_states = hidden_states / attn.rescale_output_factor |
| 1336 | + |
| 1337 | + return hidden_states |
| 1338 | + |
| 1339 | + |
1210 | 1340 | class AttnProcessor2_0: |
1211 | 1341 | r""" |
1212 | 1342 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). |
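
A minimal end-to-end sketch of the new processor (again not part of this diff), assuming a `torch_npu` install that exposes an "npu" device and using an illustrative Stable Diffusion checkpoint; fp16 keeps the fused `npu_fusion_attention` path active, while fp32 would fall back to `F.scaled_dot_product_attention`.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import AttnProcessorNPU

# Illustrative checkpoint name; fp16 so the torch_npu fused kernel is used.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("npu")

# Route every attention block in the UNet through the NPU processor.
pipe.unet.set_attn_processor(AttnProcessorNPU())

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
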