
Commit 6c1a60f

Merge branch 'main' into flux-control-lora-training-script

2 parents: f188e80 + 43534a8
19 files changed: +2831 -47 lines

docs/source/en/api/pipelines/pag.md
Lines changed: 5 additions & 0 deletions

```diff
@@ -48,6 +48,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
 - all
 - __call__
 
+## StableDiffusionPAGInpaintPipeline
+[[autodoc]] StableDiffusionPAGInpaintPipeline
+- all
+- __call__
+
 ## StableDiffusionPAGPipeline
 [[autodoc]] StableDiffusionPAGPipeline
 - all
```
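The new docs entry mirrors the other PAG pipelines. For orientation, a minimal usage sketch of the inpainting variant, assuming the established PAG API (`enable_pag`, `pag_scale`); the checkpoint and image URLs below are illustrative placeholders:

```python
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image

# enable_pag=True routes to StableDiffusionPAGInpaintPipeline;
# checkpoint and URLs are placeholders.
pipe = AutoPipelineForInpainting.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    enable_pag=True,
    torch_dtype=torch.float16,
).to("cuda")

init_image = load_image("https://example.com/input.png")
mask_image = load_image("https://example.com/mask.png")

result = pipe(
    prompt="a red sports car parked on the street",
    image=init_image,
    mask_image=mask_image,
    pag_scale=3.0,  # strength of perturbed-attention guidance
).images[0]
```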

examples/cogvideo/train_cogvideox_image_to_video_lora.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
         temporal_size=num_frames,
+        device=device,
     )
 
-    freqs_cos = freqs_cos.to(device=device)
-    freqs_sin = freqs_sin.to(device=device)
     return freqs_cos, freqs_sin
 
```

examples/cogvideo/train_cogvideox_lora.py
Lines changed: 1 addition & 2 deletions

```diff
@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
         temporal_size=num_frames,
+        device=device,
     )
 
-    freqs_cos = freqs_cos.to(device=device)
-    freqs_sin = freqs_sin.to(device=device)
     return freqs_cos, freqs_sin
 
```
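Both training scripts get the same simplification: the `device` argument is now forwarded into the embedding helper, so the frequency tensors are allocated on the accelerator directly instead of being built on CPU and copied over afterward. The pattern in isolation:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Before: allocate on CPU, then pay for an extra copy to the device.
freqs = torch.arange(16, dtype=torch.float32)
freqs = freqs.to(device=device)

# After: allocate directly on the target device.
freqs = torch.arange(16, dtype=torch.float32, device=device)
```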

examples/community/pipeline_flux_rf_inversion.py

Lines changed: 1061 additions & 0 deletions
Large diffs are not rendered by default.
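The new RF-inversion pipeline ships as a single community file and is not rendered in this view. Community pipelines are typically loaded by name via `custom_pipeline`; a hypothetical loading sketch (the pipeline's own call signature lives in the unrendered file, so only the loading step is shown):

```python
import torch
from diffusers import DiffusionPipeline

# custom_pipeline resolves pipeline_flux_rf_inversion from the
# diffusers community pipelines folder; its __call__ arguments are
# defined in the (unrendered) file added by this commit.
pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    custom_pipeline="pipeline_flux_rf_inversion",
    torch_dtype=torch.bfloat16,
)
```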

src/diffusers/__init__.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -363,6 +363,7 @@
         "StableDiffusionLDM3DPipeline",
         "StableDiffusionModelEditingPipeline",
         "StableDiffusionPAGImg2ImgPipeline",
+        "StableDiffusionPAGInpaintPipeline",
         "StableDiffusionPAGPipeline",
         "StableDiffusionPanoramaPipeline",
         "StableDiffusionParadigmsPipeline",
@@ -834,6 +835,7 @@
         StableDiffusionLDM3DPipeline,
         StableDiffusionModelEditingPipeline,
         StableDiffusionPAGImg2ImgPipeline,
+        StableDiffusionPAGInpaintPipeline,
         StableDiffusionPAGPipeline,
         StableDiffusionPanoramaPipeline,
         StableDiffusionParadigmsPipeline,
```

src/diffusers/image_processor.py
Lines changed: 23 additions & 13 deletions

```diff
@@ -236,7 +236,7 @@ def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, to
             `np.ndarray` or `torch.Tensor`:
                 The denormalized image array.
         """
-        return (images / 2 + 0.5).clamp(0, 1)
+        return (images * 0.5 + 0.5).clamp(0, 1)
 
     @staticmethod
     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
@@ -537,6 +537,26 @@ def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
 
         return image
 
+    def _denormalize_conditionally(
+        self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
+    ) -> torch.Tensor:
+        r"""
+        Denormalize a batch of images based on a condition list.
+
+        Args:
+            images (`torch.Tensor`):
+                The input image tensor.
+            do_denormalize (`Optional[List[bool]]`, *optional*, defaults to `None`):
+                A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
+                value of `do_normalize` in the `VaeImageProcessor` config.
+        """
+        if do_denormalize is None:
+            return self.denormalize(images) if self.config.do_normalize else images
+
+        return torch.stack(
+            [self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
+        )
+
     def get_default_height_width(
         self,
         image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
@@ -752,12 +772,7 @@ def postprocess(
         if output_type == "latent":
             return image
 
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-        )
+        image = self._denormalize_conditionally(image, do_denormalize)
 
         if output_type == "pt":
             return image
@@ -966,12 +981,7 @@ def postprocess(
         deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
         output_type = "np"
 
-        if do_denormalize is None:
-            do_denormalize = [self.config.do_normalize] * image.shape[0]
-
-        image = torch.stack(
-            [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
-        )
+        image = self._denormalize_conditionally(image, do_denormalize)
 
         image = self.pt_to_numpy(image)
```
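The two duplicated conditional-denormalize loops now share one helper, and `images / 2` becomes the numerically identical `images * 0.5`. A standalone sketch of the helper's behavior, assuming a `do_normalize=True` config:

```python
import torch

def denormalize(images: torch.Tensor) -> torch.Tensor:
    # Map [-1, 1] back to [0, 1]; x * 0.5 + 0.5 == x / 2 + 0.5.
    return (images * 0.5 + 0.5).clamp(0, 1)

batch = torch.tensor([[-1.0, 0.0, 1.0],
                      [-1.0, 0.0, 1.0]])
do_denormalize = [True, False]  # per-image switch, as in the new helper

out = torch.stack(
    [denormalize(batch[i]) if do_denormalize[i] else batch[i] for i in range(batch.shape[0])]
)
print(out[0])  # tensor([0.0000, 0.5000, 1.0000]) -- denormalized
print(out[1])  # tensor([-1.,  0.,  1.])          -- passed through
```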

src/diffusers/loaders/single_file_model.py
Lines changed: 3 additions & 2 deletions

```diff
@@ -219,7 +219,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
 
         checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
-        if original_config:
+        if original_config is not None:
             if "config_mapping_fn" in mapping_functions:
                 config_mapping_fn = mapping_functions["config_mapping_fn"]
             else:
@@ -243,7 +243,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
             )
         else:
-            if config:
+            if config is not None:
                 if isinstance(config, str):
                     default_pretrained_model_config_name = config
                 else:
@@ -270,6 +270,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                 subfolder=subfolder,
                 local_files_only=local_files_only,
                 token=token,
+                revision=revision,
             )
             expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
```
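Switching from truthiness to `is not None` matters because an explicitly passed but falsy value (an empty dict or string) used to be silently treated as absent. A minimal illustration of the difference:

```python
config = {}  # caller explicitly passed a (currently empty) config

if config:  # old check: an empty dict is falsy, so it is dropped
    branch = "use caller config"
else:
    branch = "fall back to defaults"
print(branch)  # fall back to defaults -- the argument was ignored

if config is not None:  # new check: only an omitted argument falls back
    branch = "use caller config"
else:
    branch = "fall back to defaults"
print(branch)  # use caller config
```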

src/diffusers/models/embeddings.py
Lines changed: 27 additions & 13 deletions

```diff
@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
     use_real: bool = True,
     grid_type: str = "linspace",
     max_size: Optional[Tuple[int, int]] = None,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
     if grid_type == "linspace":
         start, stop = crops_coords
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
-        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+        grid_h = torch.linspace(
+            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+        )
+        grid_w = torch.linspace(
+            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+        )
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+        grid_t = torch.linspace(
+            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+        )
     elif grid_type == "slice":
         max_h, max_w = max_size
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.arange(max_h, dtype=np.float32)
-        grid_w = np.arange(max_w, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
     else:
         raise ValueError("Invalid value passed for `grid_type`.")
 
@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
-    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
 
     # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
     def combine_time_height_width(freqs_t, freqs_h, freqs_w):
@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
     temporal_size,
     interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
     theta: int = 10000,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     # TODO(aryan): docs
     start, stop = crops_coords
     grid_size_h, grid_size_w = grid_size
     interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+    grid_t = torch.linspace(
+        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+    )
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+    )
 
     # Compute dimensions for each axis
     dim_t = embed_dim // 3
```
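Two details of this numpy-to-torch port are worth noting. `torch.linspace` has no `endpoint=False`, so the stop value is pre-scaled by `(n - 1) / n`, which reproduces `np.linspace(start, stop, n, endpoint=False)` exactly when `start` is 0 (the case these grids hit in practice, since the crop coordinates here start at zero). The change also forwards the functions' `theta` argument to `get_1d_rotary_pos_embed` instead of always using the 1D helper's default. A quick check of the linspace pattern:

```python
import numpy as np
import torch

n, stop = 8, 16.0

ref = np.linspace(0.0, stop, n, endpoint=False, dtype=np.float32)
# torch.linspace always includes its stop value, so the stop is
# pre-scaled by (n - 1) / n to drop the endpoint instead.
port = torch.linspace(0.0, stop * (n - 1) / n, n, dtype=torch.float32)

print(ref)  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
assert np.allclose(ref, port.numpy())
```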

src/diffusers/pipelines/__init__.py
Lines changed: 2 additions & 0 deletions

```diff
@@ -174,6 +174,7 @@
         "StableDiffusion3PAGImg2ImgPipeline",
         "StableDiffusionPAGPipeline",
         "StableDiffusionPAGImg2ImgPipeline",
+        "StableDiffusionPAGInpaintPipeline",
         "StableDiffusionControlNetPAGPipeline",
         "StableDiffusionXLPAGPipeline",
         "StableDiffusionXLPAGInpaintPipeline",
@@ -595,6 +596,7 @@
         StableDiffusionControlNetPAGInpaintPipeline,
         StableDiffusionControlNetPAGPipeline,
         StableDiffusionPAGImg2ImgPipeline,
+        StableDiffusionPAGInpaintPipeline,
         StableDiffusionPAGPipeline,
         StableDiffusionXLControlNetPAGImg2ImgPipeline,
         StableDiffusionXLControlNetPAGPipeline,
```

src/diffusers/pipelines/allegro/pipeline_allegro.py
Lines changed: 4 additions & 7 deletions

```diff
@@ -623,20 +623,17 @@ def _prepare_rotary_positional_embeddings(
                 self.transformer.config.interpolation_scale_h,
                 self.transformer.config.interpolation_scale_w,
             ),
+            device=device,
         )
 
-        grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long)
-        grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long)
-        grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long)
+        grid_t = grid_t.to(dtype=torch.long)
+        grid_h = grid_h.to(dtype=torch.long)
+        grid_w = grid_w.to(dtype=torch.long)
 
         pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
        pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
         grid_t, grid_h, grid_w = pos
 
-        freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
-        freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
-        freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
-
         return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
 
     @property
```
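With the grids now returned as torch tensors already on the target device, only the dtype cast remains and the per-tuple `.to(device)` hops for the frequency pairs disappear. For reference, a small sketch of how `torch.cartesian_prod` turns the per-axis grids into one (t, h, w) index triple per video token:

```python
import torch

grid_t = torch.arange(2)  # 2 frames
grid_h = torch.arange(2)  # 2 rows of patches
grid_w = torch.arange(3)  # 3 columns of patches

pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
print(pos.shape)  # torch.Size([12, 3]) -- one triple per token
print(pos[:4])    # tensor([[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0]])
```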
