Skip to content

Commit 6313645

Browse files
yiyixuxuyiyixuxupatrickvonplaten
authored
add StableDiffusionXLKDiffusionPipeline (#6447)
--------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 2d1f218 commit 6313645

File tree

9 files changed

+1215
-5
lines changed

9 files changed

+1215
-5
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,8 @@
333333
title: Latent upscaler
334334
- local: api/pipelines/stable_diffusion/upscale
335335
title: Super-resolution
336+
- local: api/pipelines/stable_diffusion/k_diffusion
337+
title: K-Diffusion
336338
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
337339
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
338340
- local: api/pipelines/stable_diffusion/adapter
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
4+
the License. You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
9+
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
10+
specific language governing permissions and limitations under the License.
11+
-->
12+
13+
# K-Diffusion
14+
15+
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
16+
17+
Note that most the samplers from k-diffusion are implemented in Diffusers and we recommend using existing schedulers. You can find a mapping between k-diffusion samplers and schedulers in Diffusers [here](https://huggingface.co/docs/diffusers/api/schedulers/overview)
18+
19+
20+
## StableDiffusionKDiffusionPipeline
21+
22+
[[autodoc]] StableDiffusionKDiffusionPipeline
23+
24+
25+
## StableDiffusionXLKDiffusionPipeline
26+
27+
[[autodoc]] StableDiffusionXLKDiffusionPipeline

src/diffusers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@
316316
]
317317

318318
else:
319-
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline"])
319+
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
320320

321321
try:
322322
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
@@ -668,7 +668,7 @@
668668
except OptionalDependencyNotAvailable:
669669
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
670670
else:
671-
from .pipelines import StableDiffusionKDiffusionPipeline
671+
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
672672

673673
try:
674674
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):

src/diffusers/pipelines/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,10 @@
265265

266266
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
267267
else:
268-
_import_structure["stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
268+
_import_structure["stable_diffusion_k_diffusion"] = [
269+
"StableDiffusionKDiffusionPipeline",
270+
"StableDiffusionXLKDiffusionPipeline",
271+
]
269272
try:
270273
if not is_flax_available():
271274
raise OptionalDependencyNotAvailable()
@@ -491,7 +494,10 @@
491494
except OptionalDependencyNotAvailable:
492495
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
493496
else:
494-
from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
497+
from .stable_diffusion_k_diffusion import (
498+
StableDiffusionKDiffusionPipeline,
499+
StableDiffusionXLKDiffusionPipeline,
500+
)
495501

496502
try:
497503
if not is_flax_available():

src/diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
3131
else:
3232
_import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
33+
_import_structure["pipeline_stable_diffusion_xl_k_diffusion"] = ["StableDiffusionXLKDiffusionPipeline"]
3334

3435
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3536
try:
@@ -45,6 +46,7 @@
4546
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
4647
else:
4748
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
49+
from .pipeline_stable_diffusion_xl_k_diffusion import StableDiffusionXLKDiffusionPipeline
4850

4951
else:
5052
import sys

src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,15 @@ def __init__(
134134
def set_scheduler(self, scheduler_type: str):
135135
library = importlib.import_module("k_diffusion")
136136
sampling = getattr(library, "sampling")
137-
self.sampler = getattr(sampling, scheduler_type)
137+
try:
138+
self.sampler = getattr(sampling, scheduler_type)
139+
except Exception:
140+
valid_samplers = []
141+
for s in dir(sampling):
142+
if "sample_" in s:
143+
valid_samplers.append(s)
144+
145+
raise ValueError(f"Invalid scheduler type {scheduler_type}. Please choose one of {valid_samplers}.")
138146

139147
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
140148
def _encode_prompt(

0 commit comments

Comments
 (0)