
Commit 975cfae

Add tests for mlx
1 parent 316f38b commit 975cfae

21 files changed: +4570 -2 lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -511,6 +511,7 @@
 
 
 else:
+    _import_structure["models.vae_mlx"] = ["MLXAutoencoderKL"]
     _import_structure["pipelines"].extend(
         [
             "FlaxStableDiffusionControlNetPipeline",

src/diffusers/configuration_utils.py

Lines changed: 50 additions & 0 deletions
@@ -703,6 +703,56 @@ def init(self, *args, **kwargs):
     return cls
 
 
+def mlx_register_to_config(cls):
+    original_init = cls.__init__
+
+    @functools.wraps(original_init)
+    def init(self, *args, **kwargs):
+        if not isinstance(self, ConfigMixin):
+            raise RuntimeError(
+                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
+                "not inherit from `ConfigMixin`."
+            )
+
+        # Ignore private kwargs in the init. Retrieve all passed attributes.
+        init_kwargs = dict(kwargs.items())
+
+        # Retrieve default values from the dataclass fields.
+        fields = dataclasses.fields(self)
+        default_kwargs = {}
+        for field in fields:
+            # Ignore framework-internal attributes (the MLX port reuses the Flax list).
+            if field.name in self._flax_internal_args:
+                continue
+            if type(field.default) == dataclasses._MISSING_TYPE:
+                default_kwargs[field.name] = None
+            else:
+                default_kwargs[field.name] = getattr(self, field.name)
+
+        # Make sure init_kwargs override default kwargs.
+        new_kwargs = {**default_kwargs, **init_kwargs}
+        # dtype should be part of `init_kwargs`, but not `new_kwargs`.
+        if "dtype" in new_kwargs:
+            new_kwargs.pop("dtype")
+
+        # Get positional arguments aligned with kwargs.
+        for i, arg in enumerate(args):
+            name = fields[i].name
+            new_kwargs[name] = arg
+
+        # Take note of the parameters that were not present in the loaded config.
+        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
+            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))
+
+        getattr(self, "register_to_config")(**new_kwargs)
+        original_init(self, *args, **kwargs)
+
+    cls.__init__ = init
+    return cls
+
+
 class LegacyConfigMixin(ConfigMixin):
     r"""
     A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
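
The decorator mirrors `flax_register_to_config`: it wraps `__init__`, merges the dataclass-field defaults with whatever the caller passed explicitly, and records the result via `register_to_config`. A standalone, hedged illustration of just that merging logic (the `_ToyConfig` class and its values are hypothetical and not part of this commit):

import dataclasses

@dataclasses.dataclass
class _ToyConfig:
    sample_size: int = 32
    in_channels: int = 4

init_kwargs = {"sample_size": 64}  # what the caller passed explicitly
defaults = {f.name: f.default for f in dataclasses.fields(_ToyConfig)}

new_kwargs = {**defaults, **init_kwargs}    # explicit kwargs override defaults
unset = set(new_kwargs) - set(init_kwargs)  # recorded as `_use_default_values`

print(new_kwargs)  # {'sample_size': 64, 'in_channels': 4}
print(unset)       # {'in_channels'}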

src/diffusers/models/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
     DIFFUSERS_SLOW_IMPORT,
     _LazyModule,
     is_flax_available,
+    is_mlx_available,
     is_torch_available,
 )
 
@@ -73,6 +74,10 @@
     _import_structure["unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
     _import_structure["vae_flax"] = ["FlaxAutoencoderKL"]
 
+if is_mlx_available():
+    _import_structure["unets.unet_2d_condition_mlx"] = ["MLXUNet2DConditionModel"]
+    _import_structure["vae_mlx"] = ["MLXAutoencoderKL"]
+
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
@@ -131,6 +136,10 @@
         from .unets import FlaxUNet2DConditionModel
         from .vae_flax import FlaxAutoencoderKL
 
+    if is_mlx_available():
+        from .unets import MLXUNet2DConditionModel
+        from .vae_mlx import MLXAutoencoderKL
+
 else:
     import sys
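
`is_mlx_available` is imported from `..utils`; its implementation lives in another file of this commit and is not shown in this excerpt. For orientation, availability helpers in `diffusers.utils.import_utils` generally follow a pattern like the sketch below (an assumption, not the actual code added here; the real helper may also check versions or environment variables):

import importlib.util

_mlx_available = importlib.util.find_spec("mlx") is not None

def is_mlx_available() -> bool:
    return _mlx_available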

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
# Copyright 2024 Apple and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

## Inspired by the example at https://github.com/ml-explore/mlx-examples/blob/main/stable_diffusion/stable_diffusion/unet.py
## Complete implementations mirroring the Flax and PyTorch versions of diffusers are used here so the library stays easy to maintain with a similar API.


class MLXAttention(nn.Module):
    r"""
    An MLX multi-head attention module as described in https://arxiv.org/abs/1706.03762.
    It adapts the MLX base MultiHeadAttention implementation with some changes.

    Parameters:
        dims (int): The model dimensions. This is also the default
            value for the queries, keys, values, and the output.
        num_heads (int): The number of attention heads to use.
        query_input_dims (int, optional): The input dimensions of the queries.
            Default: ``dims``.
        key_input_dims (int, optional): The input dimensions of the keys.
            Default: ``dims``.
        value_input_dims (int, optional): The input dimensions of the values.
            Default: ``key_input_dims``.
        value_dims (int, optional): The dimensions of the values after the
            projection. Default: ``dims``.
        value_output_dims (int, optional): The dimensions the new values will
            be projected to. Default: ``dims``.
        bias (bool, optional): Whether or not to use a bias in the projections.
            Default: ``False``.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()

        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads

        self.to_q = nn.Linear(query_input_dims, dims, bias=bias)
        self.to_k = nn.Linear(key_input_dims, dims, bias=bias)
        self.to_v = nn.Linear(value_input_dims, value_dims, bias=bias)
        self.to_out_0 = nn.Linear(value_dims, value_output_dims, bias=bias)

    def __call__(self, hidden_states, context=None, mask=None):
        # Self-attention when no encoder context is given, cross-attention otherwise.
        context = hidden_states if context is None else context

        queries = self.to_q(hidden_states)
        keys = self.to_k(context)
        values = self.to_v(context)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.to_out_0(values_hat)


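# Example usage (illustration only, not part of this commit): a quick smoke test of
# MLXAttention as written above. Shapes are arbitrary; it only checks that self- and
# cross-attention preserve the (batch, sequence, dims) shape.
example_attn = MLXAttention(dims=64, num_heads=8)
example_hidden = mx.random.normal((2, 16, 64))   # (batch, sequence, dims)
example_context = mx.random.normal((2, 77, 64))  # e.g. text-encoder states
print(example_attn(example_hidden).shape)                   # self-attention: (2, 16, 64)
print(example_attn(example_hidden, example_context).shape)  # cross-attention: (2, 16, 64)

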
class MLXBasicTransformerBlock(nn.Module):
    r"""
    An MLX transformer block layer.

    Parameters:
        model_dims (:obj:`int`):
            Inner hidden states dimension
        n_heads (:obj:`int`):
            Number of heads
        hidden_dims (:obj:`int`, *optional*):
            Hidden states dimension of the feed-forward layer (defaults to ``4 * model_dims``)
        memory_dims (:obj:`int`, *optional*):
            Dimension of the cross-attention context (defaults to ``model_dims``)
    """

    def __init__(
        self,
        model_dims: int,
        n_heads: int,
        hidden_dims: Optional[int] = None,
        memory_dims: Optional[int] = None,
    ):
        super().__init__()

        # Self-attention with a pre-norm residual connection.
        self.norm1 = nn.LayerNorm(model_dims)
        self.attn1 = MLXAttention(model_dims, n_heads)
        self.attn1.to_out_0.bias = mx.zeros(model_dims)

        # Cross-attention over the encoder (memory) states.
        memory_dims = memory_dims or model_dims
        self.norm2 = nn.LayerNorm(model_dims)
        self.attn2 = MLXAttention(model_dims, n_heads, key_input_dims=memory_dims)
        self.attn2.to_out_0.bias = mx.zeros(model_dims)

        hidden_dims = hidden_dims or 4 * model_dims
        self.norm3 = nn.LayerNorm(model_dims)
        self.ff = MLXFeedForward(model_dims, hidden_dims)

    def __call__(self, hidden_states, context=None, attn_mask=None):
        # self attention
        residual = hidden_states
        hidden_states = self.attn1(self.norm1(hidden_states), mask=attn_mask)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context, attn_mask)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states))
        hidden_states = hidden_states + residual

        return hidden_states


class MLXTransformer2DModel(nn.Module):
    r"""
    A transformer model for inputs with 2 spatial dimensions.

    Parameters:
        in_channels (:obj:`int`):
            Input number of channels
        model_dims (:obj:`int`):
            Inner hidden states dimension
        encoder_dims (:obj:`int`):
            Dimension of the encoder (cross-attention context) hidden states
        num_heads (:obj:`int`):
            Number of heads
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of transformer blocks
        norm_num_groups (:obj:`int`, *optional*, defaults to 32):
            Number of groups used by the input group normalization
    """

    def __init__(
        self,
        in_channels: int,
        model_dims: int,
        encoder_dims: int,
        num_heads: int,
        num_layers: int = 1,
        norm_num_groups: int = 32,
    ):
        super().__init__()

        self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
        self.proj_in = nn.Linear(in_channels, model_dims)
        self.transformer_blocks = [
            MLXBasicTransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
            for _ in range(num_layers)
        ]
        self.proj_out = nn.Linear(model_dims, in_channels)

    def __call__(self, hidden_states, context, attn_mask):
        # Save the input to add to the output
        residual = hidden_states

        # Perform the input norm and projection; inputs are NHWC
        B, H, W, C = hidden_states.shape
        hidden_states = self.norm(hidden_states).reshape(B, -1, C)
        hidden_states = self.proj_in(hidden_states)

        # Apply the transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context, attn_mask)

        # Apply the output projection and reshape back to (B, H, W, C)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(B, H, W, C)

        return hidden_states + residual


class MLXFeedForward(nn.Module):
    r"""
    MLX module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
      https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions defaults to `dim * 4` in [`MLXGEGLU`].

    Parameters:
        model_dims (:obj:`int`):
            Model input states dimension
        hidden_dims (:obj:`int`, *optional*):
            Inner hidden states dimension (defaults to ``4 * model_dims``)
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
    """

    def __init__(
        self,
        model_dims: int,
        hidden_dims: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()

        # The second linear layer needs to be called net_2 for now
        # to match the parameter indexing of the PyTorch Sequential layer.
        hidden_dims = hidden_dims or 4 * model_dims
        self.net_0 = MLXGEGLU(model_dims, hidden_dims, dropout)
        self.net_2 = nn.Linear(hidden_dims, model_dims)

    def __call__(self, hidden_states):
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states


class MLXGEGLU(nn.Module):
    r"""
    MLX implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        model_dims (:obj:`int`):
            Input hidden states dimension
        hidden_dims (:obj:`int`, *optional*):
            Hidden states dimension (defaults to ``4 * model_dims``)
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
    """

    def __init__(self, model_dims: int, hidden_dims: Optional[int] = None, dropout: float = 0.0):
        super().__init__()

        hidden_dims = hidden_dims or 4 * model_dims
        # A single projection produces both the value and the gate.
        self.proj = nn.Linear(model_dims, hidden_dims * 2)
        self.dropout_layer = nn.Dropout(dropout)

    def __call__(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = mx.split(hidden_states, 2, axis=2)
        return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu))
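
The classes above can be exercised end to end once they are all defined. A hedged smoke test (illustration only, not part of this commit; all sizes are arbitrary) that checks the NHWC shape round-trip of the 2D transformer and the gating width of the feed-forward path:

import mlx.core as mx

transformer = MLXTransformer2DModel(
    in_channels=32, model_dims=64, encoder_dims=96, num_heads=8, num_layers=1
)
latents = mx.random.normal((1, 8, 8, 32))        # NHWC, unlike the NCHW PyTorch models
encoder_states = mx.random.normal((1, 77, 96))   # e.g. text-encoder hidden states
print(transformer(latents, encoder_states, None).shape)  # (1, 8, 8, 32)

geglu = MLXGEGLU(model_dims=32, hidden_dims=128)
ff = MLXFeedForward(model_dims=32)               # hidden_dims defaults to 4 * 32 = 128
x = mx.random.normal((2, 10, 32))
print(geglu(x).shape)  # (2, 10, 128): projected to 256, split into value and gate, gated back to 128
print(ff(x).shape)     # (2, 10, 32)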
