
Commit dd9dfd8
Merge branch 'bria_3_2_pipeline' of https://github.com/galbria/diffusers into bria_3_2_pipeline
2 parents: 9c6d9dd + 51a3bdc

File tree: 13 files changed, +1147 -375 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 - local: installation
   title: Installation
 - local: quicktour
-  title: Quicktour
+  title: Quickstart
 - local: stable_diffusion
   title: Basic performance

@@ -340,7 +340,7 @@
     title: AllegroTransformer3DModel
   - local: api/models/aura_flow_transformer2d
     title: AuraFlowTransformer2DModel
-  - local: api/models/transformer_bria
+  - local: api/models/bria_transformer
     title: BriaTransformer2DModel
   - local: api/models/chroma_transformer
     title: ChromaTransformer2DModel

docs/source/en/quicktour.md

Lines changed: 155 additions & 249 deletions
Large diffs are not rendered by default.

docs/source/en/stable_diffusion.md

Lines changed: 15 additions & 9 deletions
@@ -22,14 +22,17 @@ This guide recommends some basic performance tips for using the [`DiffusionPipel
 
 Reducing the amount of memory used indirectly speeds up generation and can help a model fit on device.
 
+The [`~DiffusionPipeline.enable_model_cpu_offload`] method moves a model to the CPU when it is not in use to save GPU memory.
+
 ```py
 import torch
 from diffusers import DiffusionPipeline
 
 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
 pipeline.enable_model_cpu_offload()
 
 prompt = """
@@ -44,7 +47,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
 
 Denoising is the most computationally demanding process during diffusion. Methods that optimizes this process accelerates inference speed. Try the following methods for a speed up.
 
-- Add `.to("cuda")` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.
+- Add `device_map="cuda"` to place the pipeline on a GPU. Placing a model on an accelerator, like a GPU, increases speed because it performs computations in parallel.
 - Set `torch_dtype=torch.bfloat16` to execute the pipeline in half-precision. Reducing the data type precision increases speed because it takes less time to perform computations in a lower precision.
 
 ```py
@@ -54,8 +57,9 @@ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
 
 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda
+)
 ```
 
 - Use a faster scheduler, such as [`DPMSolverMultistepScheduler`], which only requires ~20-25 steps.
@@ -88,8 +92,9 @@ Many modern diffusion models deliver high-quality images out-of-the-box. However
 
 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
 
 prompt = """
 cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
@@ -109,8 +114,9 @@ Many modern diffusion models deliver high-quality images out-of-the-box. However
 
 pipeline = DiffusionPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
-    torch_dtype=torch.bfloat16
-).to("cuda")
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
 pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
 
 prompt = """
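For reference, a minimal end-to-end sketch of the loading pattern these doc snippets converge on. This is not part of the diff; it assumes a CUDA GPU and a diffusers version whose `from_pretrained` accepts `device_map`, and the prompt is illustrative:

```py
import torch
from diffusers import DiffusionPipeline

# device_map="cuda" replaces the old .to("cuda") call; bfloat16 halves memory and speeds up math
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# optional: park idle submodels on the CPU to reduce peak VRAM, as the new doc paragraph describes
pipeline.enable_model_cpu_offload()

image = pipeline(
    "cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California"
).images[0]
image.save("cat.png")
```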

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -494,6 +494,7 @@
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
         "QwenImagePipeline",
+        "QwenImageEditPipeline",
         "ReduxImageEncoder",
         "SanaControlNetPipeline",
         "SanaPAGPipeline",
@@ -1127,6 +1128,7 @@
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
         QwenImagePipeline,

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 32 additions & 23 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        pos_freqs = torch.cat(
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        neg_freqs = torch.cat(
+        self.neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
-        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
-        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
-        # 是否使用 scale rope
+        # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
         self.scale_rope = scale_rope
 
     def rope_params(self, index, dim, theta=10000):
@@ -201,35 +198,47 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
             txt_length: [bs] a list of 1 integers representing the length of the text
         """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
-
-        if not torch.compiler.is_compiling():
-            if rope_key not in self.rope_cache:
-                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
-            vid_freqs = self.rope_cache[rope_key]
-        else:
-            vid_freqs = self._compute_video_freqs(frame, height, width)
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            vid_freqs.append(video_freq)
 
-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
 
         max_len = max(txt_seq_lens)
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)
 
         return vid_freqs, txt_freqs
 
     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width):
+    def _compute_video_freqs(self, frame, height, width, idx=0):
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
 
-        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
             freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
             freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
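The new comment is terse, so here is a small, self-contained sketch (not part of the diff) of the failure mode it refers to: `nn.Module.to(dtype)` also casts registered buffers, and casting a complex tensor to a real dtype silently drops the imaginary part. Keeping `pos_freqs`/`neg_freqs` as plain tensors, with the manual `.to(device)` check added in `forward`, sidesteps that cast.

```py
import torch
from torch import nn

class RopeBufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # complex64 rotary frequencies, analogous to pos_freqs/neg_freqs
        freqs = torch.polar(torch.ones(4), torch.arange(4, dtype=torch.float32))
        self.register_buffer("freqs", freqs, persistent=False)

demo = RopeBufferDemo()
print(demo.freqs.dtype)   # torch.complex64

# a downstream model.to(torch.bfloat16) casts every parameter *and buffer* ...
demo.to(torch.bfloat16)
print(demo.freqs.dtype)   # torch.bfloat16 -- the imaginary part has been discarded
```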

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -392,6 +392,7 @@
         "QwenImagePipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageEditPipeline",
     ]
     try:
         if not is_onnx_available():
@@ -710,7 +711,12 @@
     from .paint_by_example import PaintByExamplePipeline
     from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-    from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+    from .qwenimage import (
+        QwenImageEditPipeline,
+        QwenImageImg2ImgPipeline,
+        QwenImageInpaintPipeline,
+        QwenImagePipeline,
+    )
     from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
     from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
     _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -35,6 +36,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_qwenimage import QwenImagePipeline
+        from .pipeline_qwenimage_edit import QwenImageEditPipeline
         from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
         from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
 else:
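With both registrations in place (the lazy `_import_structure` entry and the `TYPE_CHECKING` import, plus the top-level entries in `src/diffusers/__init__.py` above), the new pipeline should be importable from the package root. An illustrative quick check, assuming this branch is installed:

```py
from diffusers import (
    QwenImageEditPipeline,   # newly exported by this commit
    QwenImageImg2ImgPipeline,
    QwenImageInpaintPipeline,
    QwenImagePipeline,
)

print(QwenImageEditPipeline.__module__)
# expected: diffusers.pipelines.qwenimage.pipeline_qwenimage_edit
```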

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 7 additions & 21 deletions
@@ -253,6 +253,9 @@ def encode_prompt(
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
 
+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -316,20 +319,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -402,8 +391,7 @@ def prepare_latents(
         shape = (batch_size, 1, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)
 
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -414,9 +402,7 @@ def prepare_latents(
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents
 
     @property
     def guidance_scale(self):
@@ -594,7 +580,7 @@ def __call__(
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -604,7 +590,7 @@ def __call__(
             generator,
             latents,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size
 
         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
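The `img_shapes` change is easy to miss: each batch entry is now a list of `(frames, latent_height, latent_width)` tuples rather than a single tuple, so a pipeline that feeds several images per prompt (such as the new edit pipeline) can append extra shapes. A small illustrative sketch of the resulting structure, assuming a 1024x1024 image and a `vae_scale_factor` of 8 (the values here are assumptions, not taken from the diff):

```py
height, width = 1024, 1024
vae_scale_factor = 8          # assumed value, for illustration only
batch_size = 2

# one inner list per prompt; each element is a (frames, latent_h, latent_w) tuple
img_shapes = [[(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)]] * batch_size
print(img_shapes)             # [[(1, 64, 64)], [(1, 64, 64)]]
```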
