Skip to content

Commit 316f38b

Browse files
committed
Feat: Add MLX backend and testing the DDPM schedular with MLX
This commit adds the mlx backend for diffusers library, the idea is to make it easy to run Diffusion models like JAX backend for TPUs.
1 parent e16fd93 commit 316f38b

File tree

10 files changed

+1437
-2
lines changed

10 files changed

+1437
-2
lines changed

setup.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
"isort>=5.5.4",
109109
"jax>=0.4.1",
110110
"jaxlib>=0.4.1",
111+
"mlx",
111112
"Jinja2",
112113
"k-diffusion>=0.0.12",
113114
"torchsde",
@@ -235,8 +236,13 @@ def run(self):
235236
else:
236237
extras["flax"] = deps_list("jax", "jaxlib", "flax")
237238

239+
if sys.platform == "darwin" and os.uname().machine == "arm64": # Apple Silicon
240+
extras["mlx"] = deps_list("mlx")
241+
else:
242+
extras["mlx"] = []
243+
238244
extras["dev"] = (
239-
extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
245+
extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"] + extras["mlx"]
240246
)
241247

242248
install_requires = [
@@ -255,7 +261,7 @@ def run(self):
255261
setup(
256262
name="diffusers",
257263
version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
258-
description="State-of-the-art diffusion in PyTorch and JAX.",
264+
description="State-of-the-art diffusion in PyTorch, JAX and MLX.",
259265
long_description=open("README.md", "r", encoding="utf-8").read(),
260266
long_description_content_type="text/markdown",
261267
keywords="deep learning diffusion jax pytorch stable diffusion audioldm",

src/diffusers/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
OptionalDependencyNotAvailable,
88
_LazyModule,
99
is_flax_available,
10+
is_mlx_available,
1011
is_k_diffusion_available,
1112
is_librosa_available,
1213
is_note_seq_available,
@@ -35,6 +36,7 @@
3536
"utils": [
3637
"OptionalDependencyNotAvailable",
3738
"is_flax_available",
39+
"is_mlx_available",
3840
"is_inflect_available",
3941
"is_invisible_watermark_available",
4042
"is_k_diffusion_available",
@@ -519,6 +521,26 @@
519521
]
520522
)
521523

524+
try:
525+
if not is_mlx_available():
526+
raise OptionalDependencyNotAvailable()
527+
except OptionalDependencyNotAvailable:
528+
from .utils import dummy_mlx_objects # noqa F403
529+
530+
_import_structure["utils.dummy_mlx_objects"] = [
531+
name for name in dir(dummy_mlx_objects) if not name.startswith("_")
532+
]
533+
534+
535+
else:
536+
_import_structure["schedulers"].extend(
537+
[
538+
"MLXDDPMScheduler",
539+
"MLXEulerDiscreteScheduler",
540+
]
541+
)
542+
543+
522544
try:
523545
if not (is_note_seq_available()):
524546
raise OptionalDependencyNotAvailable()

src/diffusers/schedulers/__init__.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
_LazyModule,
2121
get_objects_from_module,
2222
is_flax_available,
23+
is_mlx_available,
2324
is_scipy_available,
2425
is_torch_available,
2526
is_torchsde_available,
@@ -99,6 +100,19 @@
99100
"broadcast_to_shape_from_left",
100101
]
101102

103+
try:
104+
if not is_mlx_available():
105+
raise OptionalDependencyNotAvailable()
106+
107+
except OptionalDependencyNotAvailable:
108+
from ..utils import dummy_mlx_objects # noqa F403
109+
110+
_dummy_modules.update(get_objects_from_module(dummy_mlx_objects))
111+
112+
else:
113+
_import_structure["scheduling_euler_discrete_mlx"] = ["MLXEulerDiscreteScheduler"]
114+
_import_structure["scheduling_ddpm_mlx"] = ["MLXDDPMScheduler"]
115+
102116

103117
try:
104118
if not (is_torch_available() and is_scipy_available()):
@@ -127,6 +141,7 @@
127141
from ..utils import (
128142
OptionalDependencyNotAvailable,
129143
is_flax_available,
144+
is_mlx_available,
130145
is_scipy_available,
131146
is_torch_available,
132147
is_torchsde_available,
@@ -197,6 +212,20 @@
197212
)
198213

199214
try:
215+
if not is_mlx_available():
216+
raise OptionalDependencyNotAvailable()
217+
except OptionalDependencyNotAvailable:
218+
from ..utils.dummy_mlx_objects import * # noqa F403
219+
else:
220+
from .scheduling_euler_discrete_mlx import MLXEulerDiscreteScheduler
221+
from .scheduling_ddpm_mlx import MLXDDPMScheduler
222+
from .scheduling_utils_mlx import (
223+
MLXKarrasDiffusionSchedulers,
224+
MLXSchedulerMixin,
225+
MLXSchedulerOutput
226+
)
227+
228+
try:
200229
if not (is_torch_available() and is_scipy_available()):
201230
raise OptionalDependencyNotAvailable()
202231
except OptionalDependencyNotAvailable:

0 commit comments

Comments
 (0)