@@ -20,7 +20,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ..utils import is_torch_version
+from ..utils import is_torch_version, is_torch_npu_available
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
@@ -505,19 +505,30 @@ def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool
                 self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-
-        if self.weight is not None:
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-            hidden_states = hidden_states * self.weight
+        if is_torch_npu_available():
+            import torch_npu
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
             if self.bias is not None:
                 hidden_states = hidden_states + self.bias
         else:
-            hidden_states = hidden_states.to(input_dtype)
+            input_dtype = hidden_states.dtype
+            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+            if self.weight is not None:
+                # convert into half-precision if necessary
+                if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                    hidden_states = hidden_states.to(self.weight.dtype)
+                hidden_states = hidden_states * self.weight
+                if self.bias is not None:
+                    hidden_states = hidden_states + self.bias
+            else:
+                hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
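For reference, below is a minimal, self-contained sketch of the eager-mode RMSNorm math that the new fused `torch_npu.npu_rms_norm` call replaces on Ascend NPUs. The helper name `rms_norm_reference` is hypothetical and only mirrors the fallback branch of `RMSNorm.forward` shown in the diff; the commented-out comparison assumes a working `torch_npu` install with an NPU device and is not runnable elsewhere.

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight, eps: float) -> torch.Tensor:
    # Mirrors the non-NPU branch above: normalize by the root-mean-square
    # over the last dimension, then apply the learned scale if present.
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    if weight is not None:
        # convert into half-precision if necessary
        if weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(weight.dtype)
        hidden_states = hidden_states * weight
    else:
        hidden_states = hidden_states.to(input_dtype)
    return hidden_states

x = torch.randn(2, 16, 64, dtype=torch.float16)
w = torch.ones(64, dtype=torch.float16)
out = rms_norm_reference(x, w, eps=1e-6)

# On Ascend hardware the fused kernel should give a numerically close result:
# import torch_npu
# fused = torch_npu.npu_rms_norm(x.npu(), w.npu(), epsilon=1e-6)[0]
# torch.testing.assert_close(out, fused.cpu(), rtol=1e-3, atol=1e-3)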