@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numbers
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -54,6 +55,34 @@ def forward(self, hidden_states, scale=None):
         return hidden_states
 
 
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class MochiLayerNormContinuous(nn.Module):
     def __init__(
         self,
@@ -139,10 +168,10 @@ def __init__(
 
         self.heads = out_dim // dim_head if out_dim is not None else heads
 
-        self.norm_q = RMSNorm(dim_head, eps, True)
-        self.norm_k = RMSNorm(dim_head, eps, True)
-        self.norm_added_q = RMSNorm(dim_head, eps, True)
-        self.norm_added_k = RMSNorm(dim_head, eps, True)
+        self.norm_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_k = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
+        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
 
         self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
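For reference, a minimal, self-contained sketch of how the new class behaves. The class body is taken from the diff above; the `dim_head=64`, `eps=1e-5`, and tensor shape are illustrative assumptions, not values from this PR. The notable detail, visible in the diff itself, is that `MochiRMSNorm` applies the affine weight while the activations are still in float32 and only then casts back to the input dtype.

```python
import numbers

import torch
from torch import nn


class MochiRMSNorm(nn.Module):
    """RMS norm that computes statistics and applies the affine weight in
    float32, then casts the result back to the input dtype (per the diff)."""

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Mean of squares is computed in float32 for numerical stability.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        # Weight is applied before the downcast, i.e. still in float32
        # (type promotion upcasts half-precision inputs here).
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        return hidden_states.to(input_dtype)


# Illustrative usage: dim_head=64, eps=1e-5, and the (batch, heads, seq, dim_head)
# shape are assumptions for this sketch, not values taken from the PR.
norm_q = MochiRMSNorm(64, eps=1e-5, elementwise_affine=True)
query = torch.randn(2, 24, 128, 64, dtype=torch.bfloat16)
out = norm_q(query)
assert out.dtype == torch.bfloat16  # result is cast back to the input dtype
```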