
Commit cb4e44b

Merge branch 'main' into add-trtquant-backend
2 parents 1a8806f + 6549b04 commit cb4e44b

File tree

9 files changed (+132 / -94 lines changed)


docs/source/en/training/distributed_inference.md

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ from diffusers.image_processor import VaeImageProcessor
 import torch

 vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
-vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
 image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

 with torch.no_grad():
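
Context for the fix: VaeImageProcessor expects the VAE's spatial compression factor, which is 2 raised to the number of downsampling steps, i.e. one less than the number of block_out_channels entries. A quick arithmetic check, assuming a common four-entry VAE config (an illustrative assumption, not taken from this diff):

# Illustrative values only: SD/SDXL/Flux-style VAEs typically ship four block_out_channels.
block_out_channels = [128, 256, 512, 512]

old_factor = 2 ** len(block_out_channels)        # 16 -- overestimates the compression
new_factor = 2 ** (len(block_out_channels) - 1)  # 8  -- matches the VAE's 8x downsampling
print(old_factor, new_factor)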

docs/source/en/tutorials/autopipeline.md

Lines changed: 29 additions & 85 deletions
@@ -12,112 +12,56 @@ specific language governing permissions and limitations under the License.

 # AutoPipeline

-Diffusers provides many pipelines for basic tasks like generating images, videos, audio, and inpainting. On top of these, there are specialized pipelines for adapters and features like upscaling, super-resolution, and more. Different pipeline classes can even use the same checkpoint because they share the same pretrained model! With so many different pipelines, it can be overwhelming to know which pipeline class to use.
+[AutoPipeline](../api/models/auto_model) is a *task-and-model* pipeline that automatically selects the correct pipeline subclass based on the task. It handles the complexity of loading different pipeline subclasses without needing to know the specific pipeline subclass name.

-The [AutoPipeline](../api/pipelines/auto_pipeline) class is designed to simplify the variety of pipelines in Diffusers. It is a generic *task-first* pipeline that lets you focus on a task ([`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`], and [`AutoPipelineForInpainting`]) without needing to know the specific pipeline class. The [AutoPipeline](../api/pipelines/auto_pipeline) automatically detects the correct pipeline class to use.
+This is unlike [`DiffusionPipeline`], a *model-only* pipeline that automatically selects the pipeline subclass based on the model.

-For example, let's use the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint.
-
-Under the hood, [AutoPipeline](../api/pipelines/auto_pipeline):
-
-1. Detects a `"stable-diffusion"` class from the [model_index.json](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0/blob/main/model_index.json) file.
-2. Depending on the task you're interested in, it loads the [`StableDiffusionPipeline`], [`StableDiffusionImg2ImgPipeline`], or [`StableDiffusionInpaintPipeline`]. Any parameter (`strength`, `num_inference_steps`, etc.) you would pass to these specific pipelines can also be passed to the [AutoPipeline](../api/pipelines/auto_pipeline).
-
-<hfoptions id="autopipeline">
-<hfoption id="text-to-image">
+[`AutoPipelineForImage2Image`] returns a specific pipeline subclass, (for example, [`StableDiffusionXLImg2ImgPipeline`]), which can only be used for image-to-image tasks.

 ```py
-from diffusers import AutoPipelineForText2Image
 import torch
-
-pipe_txt2img = AutoPipelineForText2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-prompt = "cinematic photo of Godzilla eating sushi with a cat in a izakaya, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(37)
-image = pipe_txt2img(prompt, generator=generator).images[0]
-image
-```
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png"/>
-</div>
-
-</hfoption>
-<hfoption id="image-to-image">
-
-```py
 from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image
-import torch
-
-pipe_img2img = AutoPipelineForImage2Image.from_pretrained(
-    "dreamlike-art/dreamlike-photoreal-2.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-text2img.png")
-
-prompt = "cinematic photo of Godzilla eating burgers with a cat in a fast food restaurant, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(53)
-image = pipe_img2img(prompt, image=init_image, generator=generator).images[0]
-image
-```
-
-Notice how the [dreamlike-art/dreamlike-photoreal-2.0](https://hf.co/dreamlike-art/dreamlike-photoreal-2.0) checkpoint is used for both text-to-image and image-to-image tasks? To save memory and avoid loading the checkpoint twice, use the [`~DiffusionPipeline.from_pipe`] method.

-```py
-pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img).to("cuda")
-image = pipeline(prompt, image=init_image, generator=generator).images[0]
-image
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLImg2ImgPipeline {
+  "_class_name": "StableDiffusionXLImg2ImgPipeline",
+  ...
+"
 ```

-You can learn more about the [`~DiffusionPipeline.from_pipe`] method in the [Reuse a pipeline](../using-diffusers/loading#reuse-a-pipeline) guide.
-
-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png"/>
-</div>
-
-</hfoption>
-<hfoption id="inpainting">
+Loading the same model with [`DiffusionPipeline`] returns the [`StableDiffusionXLPipeline`] subclass. It can be used for text-to-image, image-to-image, or inpainting tasks depending on the inputs.

 ```py
-from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
 import torch
+from diffusers import DiffusionPipeline

-pipeline = AutoPipelineForInpainting.from_pretrained(
-    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True
-).to("cuda")
-
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-img2img.png")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-mask.png")
-
-prompt = "cinematic photo of a owl, 35mm photograph, film, professional, 4k, highly detailed"
-generator = torch.Generator(device="cpu").manual_seed(38)
-image = pipeline(prompt, image=init_image, mask_image=mask_image, generator=generator, strength=0.4).images[0]
-image
+pipeline = DiffusionPipeline.from_pretrained(
+    "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.bfloat16, device_map="cuda",
+)
+print(pipeline)
+"StableDiffusionXLPipeline {
+  "_class_name": "StableDiffusionXLPipeline",
+  ...
+"
 ```

-<div class="flex justify-center">
-    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/autopipeline-inpaint.png"/>
-</div>
+Check the [mappings](https://github.com/huggingface/diffusers/blob/130fd8df54f24ffb006d84787b598d8adc899f23/src/diffusers/pipelines/auto_pipeline.py#L114) to see whether a model is supported or not.

-</hfoption>
-</hfoptions>
-
-## Unsupported checkpoints
-
-The [AutoPipeline](../api/pipelines/auto_pipeline) supports [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [Stable Diffusion XL](../api/pipelines/stable_diffusion/stable_diffusion_xl), [ControlNet](../api/pipelines/controlnet), [Kandinsky 2.1](../api/pipelines/kandinsky.md), [Kandinsky 2.2](../api/pipelines/kandinsky_v22), and [DeepFloyd IF](../api/pipelines/deepfloyd_if) checkpoints.
-
-If you try to load an unsupported checkpoint, you'll get an error.
+Trying to load an unsupported model returns an error.

 ```py
-from diffusers import AutoPipelineForImage2Image
 import torch
+from diffusers import AutoPipelineForImage2Image

 pipeline = AutoPipelineForImage2Image.from_pretrained(
-    "openai/shap-e-img2img", torch_dtype=torch.float16, use_safetensors=True
+    "openai/shap-e-img2img", torch_dtype=torch.float16,
 )
 "ValueError: AutoPipeline can't find a pipeline linked to ShapEImg2ImgPipeline for None"
 ```
+
+There are three types of [AutoPipeline](../api/models/auto_model) classes, [`AutoPipelineForText2Image`], [`AutoPipelineForImage2Image`] and [`AutoPipelineForInpainting`]. Each of these classes have a predefined mapping, linking a pipeline to their task-specific subclass.
+
+When [`~AutoPipelineForText2Image.from_pretrained`] is called, it extracts the class name from the `model_index.json` file and selects the appropriate pipeline subclass for the task based on the mapping.
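
As a side note on the new wording about predefined mappings: each AutoPipeline class keeps an ordered mapping from a model family to its task-specific subclass. A minimal sketch of inspecting them, assuming the mapping names at the auto_pipeline.py link above (AUTO_TEXT2IMAGE_PIPELINES_MAPPING, AUTO_IMAGE2IMAGE_PIPELINES_MAPPING) and the "stable-diffusion-xl" key:

# Mapping names and the "stable-diffusion-xl" key are assumptions based on the linked auto_pipeline.py.
from diffusers.pipelines.auto_pipeline import (
    AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
    AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
)

# The same model family resolves to a different subclass per task.
print(AUTO_TEXT2IMAGE_PIPELINES_MAPPING["stable-diffusion-xl"])   # StableDiffusionXLPipeline
print(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["stable-diffusion-xl"])  # StableDiffusionXLImg2ImgPipeline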

docs/source/zh/training/distributed_inference.md

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ from diffusers.image_processor import VaeImageProcessor
 import torch

 vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=torch.bfloat16).to("cuda")
-vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
+vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
 image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

 with torch.no_grad():

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 3 additions & 1 deletion
@@ -1399,6 +1399,7 @@ def main(args):
     torch_dtype = torch.float16
 elif args.prior_generation_precision == "bf16":
     torch_dtype = torch.bfloat16
+
 pipeline = FluxPipeline.from_pretrained(
     args.pretrained_model_name_or_path,
     torch_dtype=torch_dtype,
@@ -1419,7 +1420,8 @@ def main(args):
 for example in tqdm(
     sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
 ):
-    images = pipeline(example["prompt"]).images
+    with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+        images = pipeline(prompt=example["prompt"]).images

     for i, image in enumerate(images):
         hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
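
The hunk above wraps class-image generation in torch.autocast so sampling actually runs in the precision requested by --prior_generation_precision. A standalone sketch of the same pattern (checkpoint id and prompt are placeholders, not taken from the script):

import torch
from diffusers import FluxPipeline

torch_dtype = torch.bfloat16  # mirrors args.prior_generation_precision == "bf16"
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch_dtype  # placeholder checkpoint
).to("cuda")

# Autocast keeps the denoising loop in the reduced precision during class-image generation.
with torch.autocast(device_type="cuda", dtype=torch_dtype):
    images = pipeline(prompt="a photo of a dog").images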

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 8 additions & 3 deletions
@@ -1131,6 +1131,7 @@ def main(args):
     torch_dtype = torch.float16
 elif args.prior_generation_precision == "bf16":
     torch_dtype = torch.bfloat16
+
 pipeline = FluxPipeline.from_pretrained(
     args.pretrained_model_name_or_path,
     torch_dtype=torch_dtype,
@@ -1151,16 +1152,16 @@ def main(args):
 for example in tqdm(
     sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
 ):
-    images = pipeline(example["prompt"]).images
+    with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+        images = pipeline(prompt=example["prompt"]).images

     for i, image in enumerate(images):
         hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
         image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
         image.save(image_filename)

 del pipeline
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+free_memory()

 # Handle the repository creation
 if accelerator.is_main_process:
@@ -1728,6 +1729,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
         device=accelerator.device,
         prompt=args.instance_prompt,
     )
+else:
+    prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+        prompts, text_encoders, tokenizers
+    )

 # Convert images to latent space
 if args.cache_latents:
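
Besides the autocast change, this file swaps the CUDA-only cache clearing for the free_memory() helper. A small sketch of the difference, assuming free_memory is imported from diffusers.training_utils as in the other training scripts:

import torch
from diffusers.training_utils import free_memory  # assumed import path

# Before: only cleared the cache when running on CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# After: one helper that garbage-collects and clears the cache of whichever accelerator is active.
free_memory()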

src/diffusers/models/attention_dispatch.py

Lines changed: 64 additions & 1 deletion
@@ -26,6 +26,7 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_kernels_available,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -35,7 +36,7 @@
     is_xformers_available,
     is_xformers_version,
 )
-from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS


 _REQUIRED_FLASH_VERSION = "2.6.3"
@@ -67,6 +68,17 @@
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None

+if DIFFUSERS_ENABLE_HUB_KERNELS:
+    if not is_kernels_available():
+        raise ImportError(
+            "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
+        )
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+else:
+    flash_attn_3_func_hub = None

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -153,6 +165,8 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
+    _FLASH_3_HUB = "_flash_3_hub"
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.

     # PyTorch native
     FLEX = "flex"
@@ -351,6 +365,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
         )

+    # TODO: add support Hub variant of FA3 varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -657,6 +682,44 @@ def _flash_attention_3(
     return (out, lse) if return_attn_probs else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out = flash_attn_3_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+        return_attn_probs=return_attn_probs,
+    )
+    # When `return_attn_probs` is True, the above returns a tuple of
+    # actual outputs and lse.
+    return (out[0], out[1]) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
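
Putting the new backend together from a user's perspective: the env var has to be set before diffusers is imported (the Hub kernel is fetched at import time), and the backend is then selected by its "_flash_3_hub" name. A hedged sketch, assuming the attention_backend context manager exposed by this module accepts the new name and using a placeholder checkpoint:

import os

# Must be set before importing diffusers; the Hub kernel is fetched at import time.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import attention_backend  # assumed context manager

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16  # placeholder checkpoint
).to("cuda")

# Route attention through the kernels-community/flash-attn3 kernel registered as "_flash_3_hub".
with attention_backend("_flash_3_hub"):
    image = pipe("a tiny robot watering plants").images[0]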

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES

 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
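
The new flag follows the same ENV_VARS_TRUE_VALUES pattern as the other toggles in this file, so the usual truthy spellings all work. A tiny sketch of the parsing (the set literal below is an assumption about ENV_VARS_TRUE_VALUES, not copied from the diff):

import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed to match diffusers' constant

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"
enabled = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
print(enabled)  # True
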
src/diffusers/utils/kernels_utils.py (new file)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+from ..utils import get_logger
+from .import_utils import is_kernels_available
+
+
+logger = get_logger(__name__)
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+
+
+def _get_fa3_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            # TODO: temporary revision for now. Remove when merged upstream into `main`.
+            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
+            return flash_attn_3_hub
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            raise
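
For orientation, this helper is a thin wrapper over the kernels library's get_kernel, pinned to a temporary revision. A hedged usage sketch of what the fetched interface looks like from the caller's side (tensor shapes and the causal flag are illustrative; keyword names mirror _flash_attention_3_hub above):

import torch
from diffusers.utils.kernels_utils import _get_fa3_from_hub  # private helper added in this commit

fa3 = _get_fa3_from_hub()  # returns None when the `kernels` package is not installed
if fa3 is not None:
    # FA3 expects (batch, seq_len, num_heads, head_dim) tensors in bf16/fp16.
    q = k = v = torch.randn(1, 1024, 8, 64, dtype=torch.bfloat16, device="cuda")
    out = fa3.flash_attn_func(q=q, k=k, v=v, causal=False)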

tests/quantization/quanto/test_quanto.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
-    require_big_accelerator,
+    require_accelerator,
     require_torch_cuda_compatibility,
     torch_device,
 )
@@ -31,7 +31,7 @@


 @nightly
-@require_big_accelerator
+@require_accelerator
 @require_accelerate
 class QuantoBaseTesterMixin:
     model_id = None
