# Copyright 2024 Apple and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

## Inspired by the example here:
## https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
## I am using complete implementations, as in the Flax and PyTorch implementations of diffusers,
## to make the library easy to maintain in the future with a similar API.

class MLXAttention(nn.Module):
    r"""
    An MLX multi-head attention module as described in: https://arxiv.org/abs/1706.03762
    Based on the MLX MultiHeadAttention implementation, with some changes.

    Parameters:
        dims (int): The model dimensions. This is also the default
            value for the queries, keys, values, and the output.
        num_heads (int): The number of attention heads to use.
        query_input_dims (int, optional): The input dimensions of the queries.
            Default: ``dims``.
        key_input_dims (int, optional): The input dimensions of the keys.
            Default: ``dims``.
        value_input_dims (int, optional): The input dimensions of the values.
            Default: ``key_input_dims``.
        value_dims (int, optional): The dimensions of the values after the
            projection. Default: ``dims``.
        value_output_dims (int, optional): The dimensions the new values will
            be projected to. Default: ``dims``.
        bias (bool, optional): Whether or not to use a bias in the projections.
            Default: ``False``.
    """
    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads

        self.to_q = nn.Linear(query_input_dims, dims, bias=bias)
        self.to_k = nn.Linear(key_input_dims, dims, bias=bias)
        self.to_v = nn.Linear(value_input_dims, value_dims, bias=bias)
        self.to_out_0 = nn.Linear(value_dims, value_output_dims, bias=bias)
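
        # Note: the output projection is created without a bias here; callers such as
        # MLXBasicTransformerBlock attach a zero `bias` array afterwards, and MLX's
        # nn.Linear picks the parameter up at call time once it is set.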

    def __call__(self, hidden_states, context=None, mask=None):
        context = hidden_states if context is None else context

        queries = self.to_q(hidden_states)
        keys = self.to_k(context)
        values = self.to_v(context)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.to_out_0(values_hat)
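
    # Shape reference: `hidden_states` is (B, L, dims) and `context` is
    # (B, S, key_input_dims); an additive `mask` must broadcast against the
    # (B, num_heads, L, S) attention scores, and the output is (B, L, value_output_dims).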


class MLXBasicTransformerBlock(nn.Module):
    r"""
    An MLX transformer block layer.

    Parameters:
        model_dims (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        hidden_dims (:obj:`int`, *optional*):
            Hidden states dimension of the feed-forward layer. Defaults to ``4 * model_dims``.
        memory_dims (:obj:`int`, *optional*):
            Dimension of the cross-attention (encoder) hidden states. Defaults to ``model_dims``.
    """
    def __init__(
        self,
        model_dims: int,
        n_heads: int,
        hidden_dims: Optional[int] = None,
        memory_dims: Optional[int] = None,
    ):
        super().__init__()

        self.norm1 = nn.LayerNorm(model_dims)
        self.attn1 = MLXAttention(model_dims, n_heads)
        self.attn1.to_out_0.bias = mx.zeros(model_dims)

        memory_dims = memory_dims or model_dims
        self.norm2 = nn.LayerNorm(model_dims)
        self.attn2 = MLXAttention(model_dims, n_heads, key_input_dims=memory_dims)
        self.attn2.to_out_0.bias = mx.zeros(model_dims)

        hidden_dims = hidden_dims or 4 * model_dims
        self.norm3 = nn.LayerNorm(model_dims)
        self.ff = MLXFeedForward(model_dims, hidden_dims)

    def __call__(self, hidden_states, context=None, attn_mask=None):
        # self attention
        residual = hidden_states
        hidden_states = self.attn1(self.norm1(hidden_states))
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, attn_mask)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states))
        hidden_states = hidden_states + residual

        return hidden_states
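
    # Note: every sub-layer is pre-normalized with a residual connection; `context`
    # carries the encoder hidden states (e.g. text embeddings) for cross-attention,
    # and `attn_mask` is applied only in that cross-attention step.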


class MLXTransformer2DModel(nn.Module):
    r"""
    A transformer model for inputs with 2 spatial dimensions.

    Parameters:
        in_channels (:obj:`int`):
            Number of input channels
        model_dims (:obj:`int`):
            Inner hidden states dimension of the transformer blocks
        encoder_dims (:obj:`int`):
            Dimension of the encoder (cross-attention) hidden states
        num_heads (:obj:`int`):
            Number of heads
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of transformer blocks
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Number of groups for the input group normalization
    """
    def __init__(
        self,
        in_channels: int,
        model_dims: int,
        encoder_dims: int,
        num_heads: int,
        num_layers: int = 1,
        norm_num_groups: int = 32,
    ):
        super().__init__()

        self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
        self.proj_in = nn.Linear(in_channels, model_dims)
        self.transformer_blocks = [
            MLXBasicTransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
            for _ in range(num_layers)
        ]
        self.proj_out = nn.Linear(model_dims, in_channels)

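    # MLX uses a channels-last layout: `hidden_states` comes in as (B, H, W, C), the
    # spatial dimensions are flattened into a sequence of length H * W for the
    # transformer blocks, and the result is reshaped back before the residual add.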
    def __call__(self, hidden_states, context, attn_mask):
        # Save the input to add to the output
        residual = hidden_states

        # Perform the input norm and projection
        B, H, W, C = hidden_states.shape
        hidden_states = self.norm(hidden_states).reshape(B, -1, C)
        hidden_states = self.proj_in(hidden_states)

        # Apply the transformer
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context, attn_mask)

        # Apply the output projection and reshape
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(B, H, W, C)

        return hidden_states + residual


class MLXFeedForward(nn.Module):
    r"""
    MLX module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
      https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`MLXGEGLU`].

    Parameters:
        model_dims (:obj:`int`):
            Model input states dimension
        hidden_dims (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
    """
    def __init__(
        self,
        model_dims: int,
        hidden_dims: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()

        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        hidden_dims = hidden_dims or 4 * model_dims
        self.net_0 = MLXGEGLU(model_dims, hidden_dims, dropout)
        self.net_2 = nn.Linear(hidden_dims, model_dims)

    def __call__(self, hidden_states):
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states
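
    # Note: `net_0` projects to 2 * hidden_dims and gates half of it with GELU, so the
    # effective feed-forward width keeps the usual 4x expansion of a transformer MLP.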


class MLXGEGLU(nn.Module):
    r"""
    MLX implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        model_dims (:obj:`int`):
            Input hidden states dimension
        hidden_dims (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
    """
    def __init__(self, model_dims: int, hidden_dims: Optional[int] = None, dropout: float = 0.0):
        super().__init__()

        hidden_dims = hidden_dims or 4 * model_dims
        self.proj = nn.Linear(model_dims, hidden_dims * 2)
        self.dropout_layer = nn.Dropout(dropout)

    def __call__(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = mx.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu))
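

# Minimal usage sketch (hypothetical shapes and values, not part of the public diffusers API):
#
#   x = mx.random.normal((1, 32, 32, 320))     # latent feature map, channels-last (B, H, W, C)
#   ctx = mx.random.normal((1, 77, 768))       # e.g. CLIP text-encoder hidden states
#   transformer = MLXTransformer2DModel(
#       in_channels=320, model_dims=320, encoder_dims=768, num_heads=8, num_layers=1
#   )
#   out = transformer(x, ctx, attn_mask=None)  # same shape as x: (1, 32, 32, 320)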